diff --git a/configs/multi_mo_multi_task.yaml b/configs/multi_mo_multi_task.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ef0b7afa54b90c114b889ae98a4f46e9a45de334 --- /dev/null +++ b/configs/multi_mo_multi_task.yaml @@ -0,0 +1,156 @@ +train_dataset: + dataset: + name: paired-image-folders-multi-task + args: +# root_path_1: ./SAM_DATA_UNIFY/Overall_Update/split_image +# root_path_1: ./SAM_DATA_UNIFY2/OVERALL/split_image +# root_path_1: ./SAM_DATA_UNIFY2/ISAID/split_image +# root_path_1: [{'ISAID': './SAM_DATA_UNIFY2/ISAID/split_image', 'WHU': './SAM_DATA_UNIFY2/WHU-OPT/split_images'}] +# root_path_1: [{'Decoder1': "/workspace/SAM_DATA_UNIFY3/Decoder1/split_image/", 'Decoder2': "/workspace/SAM_DATA_UNIFY3/Decoder2/split_image/"}] + root_path_1: [{'Decoder1': "/workspace/SAM_DATA_UNIFY4/Decoder1/image/", 'Decoder2': "/workspace/SAM_DATA_UNIFY4/Decoder2/image/"}] +# root_path_2: ./SAM_DATA_UNIFY/Overall_Update/split_gt +# root_path_2: ./SAM_DATA_UNIFY2/OVERALL/split_gt +# root_path_2: ./SAM_DATA_UNIFY2/ISAID/split_gt +# root_path_2: [{'ISAID': './SAM_DATA_UNIFY2/ISAID/split_gt', 'WHU': './SAM_DATA_UNIFY2/WHU-OPT/split_gt'}] +# root_path_2: [{'Decoder1': "/workspace/SAM_DATA_UNIFY3/Decoder1/split_gt/", 'Decoder2': "/workspace/SAM_DATA_UNIFY3/Decoder2/split_gt/"}] + root_path_2: [{'Decoder1': "/workspace/SAM_DATA_UNIFY4/Decoder1/gt/", 'Decoder2': "/workspace/SAM_DATA_UNIFY4/Decoder2/gt/"}] + cache: nones + split_key: train + wrapper: + name: train_multi_task + args: + inp_size: 1024 + augment: false +# batch_size: 2 + batch_size: 2 + +val_dataset: + dataset: + name: paired-image-folders-multi-task + args: +# root_path_1: ./SAM_DATA_UNIFY2/OVERALL/split_image +# root_path_1: [{'ISAID': './SAM_DATA_UNIFY2/ISAID/split_image', 'WHU': './SAM_DATA_UNIFY2/WHU-OPT/split_images'}] +# root_path_1: [{'Decoder1': "/workspace/SAM_DATA_UNIFY3/Decoder1/split_image/", 'Decoder2': "/workspace/SAM_DATA_UNIFY3/Decoder2/split_image/"}] + root_path_1: 
[{'Decoder1': "/workspace/SAM_DATA_UNIFY4/Decoder1/image/", 'Decoder2': "/workspace/SAM_DATA_UNIFY4/Decoder2/image/"}] +# root_path_2: ./SAM_DATA_UNIFY2/OVERALL/split_gt +# root_path_2: [{'ISAID': './SAM_DATA_UNIFY2/ISAID/split_gt', 'WHU': './SAM_DATA_UNIFY2/WHU-OPT/split_gt'}] +# root_path_2: [{'Decoder1': "/workspace/SAM_DATA_UNIFY3/Decoder1/split_gt/", 'Decoder2': "/workspace/SAM_DATA_UNIFY3/Decoder2/split_gt/"}] + root_path_2: [{'Decoder1': "/workspace/SAM_DATA_UNIFY4/Decoder1/gt/", 'Decoder2': "/workspace/SAM_DATA_UNIFY4/Decoder2/gt/"}] + cache: none + split_key: test + wrapper: + name: val_multi_task + args: + inp_size: 1024 +# batch_size: 2 + batch_size: 1 + +test_dataset: + dataset: + name: paired-image-folders + args: + +# root_path_1: ./SAM_DATA_UNIFY3/ISAID/split_image +# root_path_1: ./SAM_DATA_UNIFY3/GANFEN/split_image +# root_path_1: ./SAM_DATA_UNIFY3/SAR2020/split_image_ov500 +# root_path_1: ./SAM_DATA_UNIFY3/ISAID/split_image +# root_path_1: ./SAM_DATA_UNIFY4/SAR2020/split_image_ov500 +# root_path_1: ./SAM_DATA_UNIFY4/GAOFEN/split_image +# root_path_1: ./SAM_DATA_UNIFY4/Vaihingen/image1 +# root_path_1: ./SAM_DATA_UNIFY4/SAR2020/split_image_ov500 +# root_path_1: ./SAM_DATA_UNIFY4/Potsdam/image1 +# root_path_1: ./SAM_DATA_UNIFY4/whu-opt-sar/image_sar + root_path_1: /workspace/AIService/FoundationModel/sam_adapter_01/TwoDecoder_data/Prompt_GUOLV_Data/prompt_all1/image + +# root_path_2: ./SAM_DATA_UNIFY3/ISAID/split_gt +# root_path_2: ./SAM_DATA_UNIFY3/GANFEN/gt_decoder1 +# root_path_2: ./SAM_DATA_UNIFY3/GANFEN/gt_decoder2 +# root_path_2: ./SAM_DATA_UNIFY3/SAR2020/gt_decoder2 +# root_path_2: ./SAM_DATA_UNIFY3/ISAID/split_gt +# root_path_2: ./SAM_DATA_UNIFY4/SAR2020/gt_decoder2 +# root_path_2: ./SAM_DATA_UNIFY4/GAOFEN/gt_decoder1_update +# root_path_2: ./SAM_DATA_UNIFY4/Vaihingen/gt2 +# root_path_2: ./SAM_DATA_UNIFY4/Potsdam/gt1 +# root_path_2: ./SAM_DATA_UNIFY4/SAR2020/gt_decoder2 + root_path_2: 
/workspace/AIService/FoundationModel/sam_adapter_01/TwoDecoder_data/Prompt_GUOLV_Data/prompt_all1/gt +# root_path_2: ./SAM_DATA_UNIFY4/whu-opt-sar/gt_sar + cache: none + split_key: test + wrapper: + name: val + args: +# inp_size: 1024 + inp_size: 1024 + batch_size: 1 + +#eval_type: cod +eval_type: f1 +#sam_checkpoint: ./pretrained/sam_vit_l_0b3195.pth +sam_checkpoint: sam_vit_h_4b8939.pth +data_norm: + inp: + sub: + - 0.5 + div: + - 0.5 + gt: + sub: + - 0.5 + div: + - 0.5 + gt_rgb: + sub: + - 0.5 + div: + - 0.5 +model: + name: sam + args: + inp_size: 1024 +# loss: iou + loss: cr + encoder_mode: + name: sam + img_size: 1024 + mlp_ratio: 4 + patch_size: 16 + qkv_bias: true + use_rel_pos: true + window_size: 14 + out_chans: 256 + scale_factor: 32 + input_type: fft + freq_nums: 0.25 + prompt_type: highpass + prompt_embed_dim: 256 + tuning_stage: 1234 + handcrafted_tune: true + embedding_tune: true + adaptor: adaptor + embed_dim: 1280 + depth: 32 + num_heads: 16 + global_attn_indexes: + - 7 + - 15 + - 23 + - 31 +optimizer: + name: adamw + args: +# lr: 0.0002 +# lr: 0.00002 + lr: 0.00008 +lr_min: 1.0e-8 +#epoch_max: 20 +epoch_max: 100 + +multi_step_lr: + milestones: + - 1 + gamma: 0.1 +epoch_val: 100 +epoch_save: 1 + +#resume: 60 +#start_epoch: 60 diff --git a/configs/multi_mo_multi_task_sar_prompt.yaml b/configs/multi_mo_multi_task_sar_prompt.yaml new file mode 100644 index 0000000000000000000000000000000000000000..39edb9a93096a7fcc055cc218a0991abbe5818f6 --- /dev/null +++ b/configs/multi_mo_multi_task_sar_prompt.yaml @@ -0,0 +1,174 @@ +train_dataset: + dataset: + name: paired-image-folders + args: +# root_path_1: ./ISAID/train/trainprompt/sub_images +# root_path_1: ./ISAID/train/trainprompt/images + root_path_1: ./SAR_prompt/image +# root_path_1: ./SAM_DATA_UNIFY2/OVERALL/split_image +# root_path_1: ./SAM_DATA_UNIFY2/ISAID/split_image +# root_path_1: [{'ISAID': './SAM_DATA_UNIFY2/ISAID/split_image', 'WHU': './SAM_DATA_UNIFY2/WHU-OPT/split_images'}] +# root_path_1: 
[{'Decoder1': "/workspace/SAM_DATA_UNIFY3/Decoder1/split_image/", 'Decoder2': "/workspace/SAM_DATA_UNIFY3/Decoder2/split_image/"}] +# root_path_1: [{'Decoder1': "/workspace/SAM_DATA_UNIFY4/Potsdam/image1/", 'Decoder2': "/workspace/SAM_DATA_UNIFY4/Decoder2/image/"}] +# root_path_2: ./ISAID/train/trainprompt/sub_gt + root_path_2: ./SAR_prompt/gt +# root_path_2: ./SAM_DATA_UNIFY2/OVERALL/split_gt +# root_path_2: ./SAM_DATA_UNIFY2/ISAID/split_gt +# root_path_2: [{'ISAID': './SAM_DATA_UNIFY2/ISAID/split_gt', 'WHU': './SAM_DATA_UNIFY2/WHU-OPT/split_gt'}] +# root_path_2: [{'Decoder1': "/workspace/SAM_DATA_UNIFY3/Decoder1/split_gt/", 'Decoder2': "/workspace/SAM_DATA_UNIFY3/Decoder2/split_gt/"}] +# root_path_2: [{'Decoder1': "/workspace/SAM_DATA_UNIFY4/Potsdam/gt1/", 'Decoder2': "/workspace/SAM_DATA_UNIFY4/Decoder2/gt/"}] + cache: none + split_key: train + wrapper: + name: train + args: + inp_size: 1024 + augment: false +# batch_size: 2 + batch_size: 1 + +val_dataset: + dataset: + name: paired-image-folders + args: +# root_path_1: ./ISAID/train/trainprompt/images + root_path_1: ./SAR_prompt/image +# root_path_1: [{'ISAID': './SAM_DATA_UNIFY2/ISAID/split_image', 'WHU': './SAM_DATA_UNIFY2/WHU-OPT/split_images'}] +# root_path_1: [{'Decoder1': "/workspace/SAM_DATA_UNIFY3/Decoder1/split_image/", 'Decoder2': "/workspace/SAM_DATA_UNIFY3/Decoder2/split_image/"}] +# root_path_1: [{'Decoder1': "/workspace/SAM_DATA_UNIFY4/Potsdam/image1/", 'Decoder2': "/workspace/SAM_DATA_UNIFY4/Decoder2/image/"}] +# root_path_2: ./ISAID/train/trainprompt/gt + root_path_2: ./SAR_prompt/gt +# root_path_2: [{'ISAID': './SAM_DATA_UNIFY2/ISAID/split_gt', 'WHU': './SAM_DATA_UNIFY2/WHU-OPT/split_gt'}] +# root_path_2: [{'Decoder1': "/workspace/SAM_DATA_UNIFY3/Decoder1/split_gt/", 'Decoder2': "/workspace/SAM_DATA_UNIFY3/Decoder2/split_gt/"}] +# root_path_2: [{'Decoder1': "/workspace/SAM_DATA_UNIFY4/Potsdam/gt1/", 'Decoder2': "/workspace/SAM_DATA_UNIFY4/Decoder2/gt/"}] + cache: none + split_key: test + 
wrapper: + name: val + args: + inp_size: 1024 +# batch_size: 2 + batch_size: 1 + +test_dataset: + dataset: + name: paired-image-folders + args: +# root_path_1: ./ISAID/train/trainprompt/images +# root_path_1: ./ISAID/train/trainprompt/sub_images + root_path_1: ./save/SAR_prompt/image +# root_path_1: ./SAM_DATA_UNIFY/Vaihingen/split_image +# root_path_1: ./SAM_DATA_UNIFY/SAR2020/split_image_ov500 +# root_path_1: ./SAM_DATA_UNIFY/POLARIS_SAR/split_image +# root_path_1: ./SAM_DATA_UNIFY/Overall_Update/split_image +# root_path_1: ./SAM_DATA_UNIFY2/ISAID/split_image +# root_path_1: ./SAM_DATA_UNIFY2/whu-sar-test/split_image +# root_path_1: ./SAM_DATA_UNIFY2/WHU-SAR/split_image +# root_path_1: ./SAM_DATA_UNIFY2/WHU_ALL/split_image +# root_path_1: ./SAM_DATA_UNIFY3/WHU_SAR/split_image +# root_path_1: ./SAM_DATA_UNIFY3/WHU_OPT/split_image +# root_path_1: ./SAM_DATA_UNIFY3/ISAID/split_image +# root_path_1: ./SAM_DATA_UNIFY3/GANFEN/split_image +# root_path_1: ./SAM_DATA_UNIFY4/SAR2020/split_image_ov500 + +# root_path_2: ./ISAID/train/trainprompt/gt +# root_path_2: ./ISAID/train/trainprompt/sub_gt + root_path_2: ./save/SAR_prompt/gt +# root_path_2: ./SAM_DATA_UNIFY/Vaihingen/split_gt +# root_path_2: ./SAM_DATA_UNIFY2/ISAID/split_gt +# root_path_2: ./SAM_DATA_UNIFY/POLARIS_SAR/split_gt +# root_path_2: ./SAM_DATA_UNIFY/Overall_Update/split_gt +# root_path_2: ./SAM_DATA_UNIFY2/ISAID/split_gt +# root_path_2: ./SAM_DATA_UNIFY2/whu-sar-test/split_gt +# root_path_2: ./SAM_DATA_UNIFY2/WHU-SAR/split_gt +# root_path_2: ./SAM_DATA_UNIFY2/WHU_ALL/split_gt +# root_path_2: ./SAM_DATA_UNIFY3/WHU_SAR/split_gt +# root_path_2: ./SAM_DATA_UNIFY3/WHU_OPT/split_gt +# root_path_2: ./SAM_DATA_UNIFY3/ISAID/split_gt +# root_path_2: ./SAM_DATA_UNIFY3/GANFEN/gt_decoder1 +# root_path_2: ./SAM_DATA_UNIFY3/GANFEN/gt_decoder2 +# root_path_2: ./SAM_DATA_UNIFY4/SAR2020/gt_decoder2 + cache: none + split_key: test + wrapper: + name: val + args: +# inp_size: 1024 + inp_size: 1024 + batch_size: 1 + +#eval_type: 
cod +eval_type: f1 +#sam_checkpoint: ./pretrained/sam_vit_l_0b3195.pth +#sam_checkpoint: sam_vit_h_4b8939.pth +sam_checkpoint: ./save/_multi_mo_multi_task_0626/model_epoch_last.pth +#sam_checkpoint: ./save/_multi_mo_multi_task_0626/model_epoch_last.pth +data_norm: + inp: + sub: + - 0.5 + div: + - 0.5 + gt: + sub: + - 0.5 + div: + - 0.5 + gt_rgb: + sub: + - 0.5 + div: + - 0.5 +model: + name: sam + args: + inp_size: 1024 +# loss: iou + loss: cr + encoder_mode: + name: sam + img_size: 1024 + mlp_ratio: 4 + patch_size: 16 + qkv_bias: true + use_rel_pos: true + window_size: 14 + out_chans: 256 + scale_factor: 32 + input_type: fft + freq_nums: 0.25 + prompt_type: highpass + prompt_embed_dim: 256 + tuning_stage: 1234 + handcrafted_tune: true + embedding_tune: true + adaptor: adaptor + embed_dim: 1280 + depth: 32 + num_heads: 16 + global_attn_indexes: + - 7 + - 15 + - 23 + - 31 +optimizer: + name: adamw + args: +# lr: 0.0002 +# lr: 0.00002 +# lr: 0.00004 +# lr: 0.00008 + lr: 0.0002 +lr_min: 1.0e-8 +#epoch_max: 20 +epoch_max: 200 + +multi_step_lr: + milestones: + - 1 + gamma: 0.1 +epoch_val: 200 +epoch_save: 1 + +#resume: 60 +#start_epoch: 60 diff --git a/datasets/__init__.py b/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6eb69ac5ef38dc11ffa93eac4943d62673b81287 --- /dev/null +++ b/datasets/__init__.py @@ -0,0 +1,3 @@ +from .datasets import register, make +from . import image_folder +from . 
import wrappers diff --git a/datasets/__pycache__/__init__.cpython-310.pyc b/datasets/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc80fd1243b0dfc247f6b1bbea326d7d612e7494 Binary files /dev/null and b/datasets/__pycache__/__init__.cpython-310.pyc differ diff --git a/datasets/__pycache__/__init__.cpython-37.pyc b/datasets/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd97b9fdd7bdef2d27334d5a11cb2de589becad8 Binary files /dev/null and b/datasets/__pycache__/__init__.cpython-37.pyc differ diff --git a/datasets/__pycache__/datasets.cpython-310.pyc b/datasets/__pycache__/datasets.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4838da5b59eb62cbe04908ee238c9d7918e64a7d Binary files /dev/null and b/datasets/__pycache__/datasets.cpython-310.pyc differ diff --git a/datasets/__pycache__/datasets.cpython-37.pyc b/datasets/__pycache__/datasets.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25c961d09daa19be7b8370764c22b90ac6504ac4 Binary files /dev/null and b/datasets/__pycache__/datasets.cpython-37.pyc differ diff --git a/datasets/__pycache__/image_folder.cpython-310.pyc b/datasets/__pycache__/image_folder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9dd0e724162669d67a7a0c37355a8ec086e15e50 Binary files /dev/null and b/datasets/__pycache__/image_folder.cpython-310.pyc differ diff --git a/datasets/__pycache__/image_folder.cpython-37.pyc b/datasets/__pycache__/image_folder.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1412c923627a4bdcc49b4f53160acd3e0bd5047d Binary files /dev/null and b/datasets/__pycache__/image_folder.cpython-37.pyc differ diff --git a/datasets/__pycache__/wrappers.cpython-310.pyc b/datasets/__pycache__/wrappers.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..1b51bf7cfe8198ce39d29d082d4c9da879b2167b Binary files /dev/null and b/datasets/__pycache__/wrappers.cpython-310.pyc differ diff --git a/datasets/__pycache__/wrappers.cpython-37.pyc b/datasets/__pycache__/wrappers.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe736d4dc22b3b9f3a7ba2793ee2b8ad83bc2f77 Binary files /dev/null and b/datasets/__pycache__/wrappers.cpython-37.pyc differ diff --git a/datasets/data_loader_multi_tasks.py b/datasets/data_loader_multi_tasks.py new file mode 100644 index 0000000000000000000000000000000000000000..d1337cd7b9a582ac534509acb6202ebcb5f19e3c --- /dev/null +++ b/datasets/data_loader_multi_tasks.py @@ -0,0 +1,26 @@ + +def build_loader_simmim(config): + ############ single model ##################### + # transform = SimMIMTransform(config) + # dataset = ImageFolder(config.DATA.DATA_PATH, transform) + # sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True) + # dataloader = DataLoader(dataset, config.DATA.BATCH_SIZE, sampler=sampler, num_workers=config.DATA.NUM_WORKERS, pin_memory=True, drop_last=True, collate_fn=collate_fn) + + ############## multi model #################### + datasets = [] + ### 数据增强 ###### + model_paths = config.DATA.TYPE_PATH[0] + for i in model_paths.keys(): + a = config.DATA.SCALE[0][i].split(',') + scale_model = (float(a[0].split('(')[1]) ,float(a[1].split(')')[0])) + transform = SimMIMTransform(config, config.DATA.NORM[0][i], scale_model) + dataset = CachedImageFolder(model_paths[i], transform = transform, model = i) + datasets.append(dataset) + multi_task_train_dataset = MultiTaskDataset(datasets) + print(len(datasets)) + multi_task_batch_sampler = DistrubutedMultiTaskBatchSampler(datasets, batch_size=config.DATA.BATCH_SIZE, num_replicas=dist.get_world_size(), rank=dist.get_rank(), mix_opt=0, extra_task_ratio=0, drop_last=True ,shuffle =True) + dataloader = 
DataLoader(multi_task_train_dataset, batch_sampler=multi_task_batch_sampler, num_workers=config.DATA.NUM_WORKERS, pin_memory=True, collate_fn=collate_fn) + # dataloader = DataLoader(multi_task_train_dataset, batch_sampler=multi_task_batch_sampler, pin_memory=True, collate_fn=collate_fn) + print(len(dataloader)) + + return dataloader \ No newline at end of file diff --git a/datasets/data_simmim_pt.py b/datasets/data_simmim_pt.py new file mode 100644 index 0000000000000000000000000000000000000000..2de9d2165c185e431cbf4fb73b8f51dc21316b58 --- /dev/null +++ b/datasets/data_simmim_pt.py @@ -0,0 +1,271 @@ +# -------------------------------------------------------- +# SimMIM +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Zhenda Xie +# -------------------------------------------------------- + +import math +import random +import numpy as np + +import torch +import torch.distributed as dist +import torchvision.transforms as T +from torch.utils.data import DataLoader, DistributedSampler +from torch.utils.data._utils.collate import default_collate +from torchvision.datasets import ImageFolder +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torch.utils.data import Dataset, BatchSampler +from torchvision.io import read_image +from .cached_image_folder import CachedImageFolder + +class MultiTaskDataset(Dataset): + """ + useage example: + train_datasets = [SemData_Single(), SemData_Single()] + multi_task_train_dataset = MultiTaskDataset(train_datasets) + multi_task_batch_sampler = MultiTaskBatchSampler(train_datasets, batch_size=4, mix_opt=0, extra_task_ratio=0, drop_last=True) + multi_task_train_data = DataLoader(multi_task_train_dataset, batch_sampler=multi_task_batch_sampler) + for i, (task_id, input, target) in enumerate(multi_task_train_data): + pre = model(input) + """ + def __init__(self, datasets): + self._datasets = datasets + task_id_2_data_set_dic = {} + for i, dataset in 
enumerate(datasets): + task_id = i + assert task_id not in task_id_2_data_set_dic, "Duplicate task_id %s" % task_id + task_id_2_data_set_dic[task_id] = dataset + + self._task_id_2_data_set_dic = task_id_2_data_set_dic + + def __len__(self): + return sum(len(dataset) for dataset in self._datasets) + + def __getitem__(self, idx): + task_id, sample_id = idx + return self._task_id_2_data_set_dic[task_id][sample_id] + +class DistrubutedMultiTaskBatchSampler(BatchSampler): + """ + datasets: class the class of the Dataset + batch_size: int + mix_opt: int mix_opt ==0 shuffle all_task; mix_opt ==1 shuffle extra_task + extra_task_ratio(float, optional): the rate between task one and extra task + drop_last (bool, optional): set to ``True`` to drop the last incomplete batch, + if the dataset size is not divisible by the batch size. If ``False`` and + the size of dataset is not divisible by the batch size, then the last batch + will be smaller. (default: ``True``) + """ + def __init__(self, datasets, batch_size, num_replicas, rank, mix_opt=0, extra_task_ratio=0, drop_last=True,shuffle = True): + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + if rank >= num_replicas or rank < 0: + raise ValueError( + "Invalid rank {}, rank should be in the interval" + " [0, {}]".format(rank, num_replicas - 1)) + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + assert mix_opt in [0, 1], 'mix_opt must equal 0 or 1' + assert extra_task_ratio >= 0, 'extra_task_ratio must greater than 0' + self._datasets = datasets + self._batch_size = batch_size + self._mix_opt = mix_opt + self._extra_task_ratio = extra_task_ratio + self._drop_last = drop_last + train_data_list = [] + self.shuffle = shuffle + for dataset in datasets: + 
print(len(dataset)) + train_data_list.append(self._get_index_batches(len(dataset), batch_size, self._drop_last)) + + ######### 一个列表里存n个dataset的数据,数据也以列表形式存在,一个dataset的列表里面把数据划分成了不同的batch的index + self._train_data_list = train_data_list + self.total_len = sum(len(train_data) for train_data in self._train_data_list) + + ######### DDP ###################### + if self._drop_last and self.total_len % self.num_replicas != 0: # type: ignore[arg-type] + self.num_samples = math.ceil( + (self.total_len - self.num_replicas) / self.num_replicas # type: ignore[arg-type] + ) + else: + self.num_samples = math.ceil(self.total_len / self.num_replicas) # type: ignore[arg-type] + + self.total_size = self.num_samples * self.num_replicas + self.epoch = 0 + self.seed = 0 + + def set_epoch(self, epoch): + self.epoch = epoch + + @staticmethod + def _get_index_batches(dataset_len, batch_size, drop_last): + # index_batches = [list(range(i, min(i+batch_size, dataset_len))) for i in range(0, dataset_len, batch_size)] + index = list(range(dataset_len)) + if drop_last and dataset_len % batch_size: + del index[-(dataset_len % batch_size):] + index_batches = [index[i:i+batch_size] for i in range(0, len(index), batch_size)] + return index_batches + + def __len__(self): + # return sum(len(train_data) for train_data in self._train_data_list) + return self.num_samples + + def __iter__(self): + all_iters = [iter(item) for item in self._train_data_list] + all_indices = self._gen_task_indices(self._train_data_list, self._mix_opt, self._extra_task_ratio) + + ######### DDP ###################### + random.shuffle(all_indices) + all_indices = all_indices[self.rank:self.total_size:self.num_replicas] + assert len(all_indices) == self.num_samples + + for local_task_idx in all_indices: + # task_id = self._datasets[local_task_idx].get_task_id() + batch = next(all_iters[local_task_idx]) + # batch = batch[self.rank:len(batch):self.num_replicas] + # print(local_task_idx) + yield [(local_task_idx, sample_id) for 
sample_id in batch] + # yield iter(batch) + + @staticmethod + def _gen_task_indices(train_data_list, mix_opt, extra_task_ratio): + + ########## accoding to the number of models ########### + all_indices = [] + for i in range(len(train_data_list)): + all_indices += [i] * len(train_data_list[i]) + # print(all_indices) + return all_indices + # def set_epoch(self, epoch) + # r""" + # Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas + # use a different random ordering for each epoch. Otherwise, the next iteration of this + # sampler will yield the same ordering. + + # Args: + # epoch (int): Epoch number. + # """ + # self.epoch = epoch + + +class MaskGenerator: + def __init__(self, input_size=192, mask_patch_size=32, model_patch_size=4, mask_ratio=0.6): + self.input_size = input_size + self.mask_patch_size = mask_patch_size + self.model_patch_size = model_patch_size + self.mask_ratio = mask_ratio + + assert self.input_size % self.mask_patch_size == 0 + assert self.mask_patch_size % self.model_patch_size == 0 + + self.rand_size = self.input_size // self.mask_patch_size + self.scale = self.mask_patch_size // self.model_patch_size + + self.token_count = self.rand_size ** 2 + self.mask_count = int(np.ceil(self.token_count * self.mask_ratio)) + + def __call__(self): + mask_idx = np.random.permutation(self.token_count)[:self.mask_count] + mask = np.zeros(self.token_count, dtype=int) + mask[mask_idx] = 1 + + mask = mask.reshape((self.rand_size, self.rand_size)) + mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1) + + return mask + + +class ZeroOneNormalize(object): + def __call__(self, img): + return img.float().div(255) + +class SimMIMTransform: + def __init__(self, config, NORM, SCALE): + self.transform_img = T.Compose([ + # T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), + # T.RandomResizedCrop(config.DATA.IMG_SIZE, scale=(0.67, 1.), ratio=(3. / 4., 4. 
/ 3.)), + # T.RandomHorizontalFlip(), + # T.ToTensor(), + # T.Normalize(mean=torch.tensor(IMAGENET_DEFAULT_MEAN),std=torch.tensor(IMAGENET_DEFAULT_STD)), + + T.RandomResizedCrop(config.DATA.IMG_SIZE, scale=SCALE, ratio=(3. / 4., 4. / 3.)), + T.RandomHorizontalFlip(), + ZeroOneNormalize(), + T.Normalize(mean=torch.tensor(NORM[0]),std=torch.tensor(NORM[1])), + ]) + + if config.MODEL.TYPE in ['swin', 'swinv2']: + model_patch_size=config.MODEL.SWIN.PATCH_SIZE + else: + raise NotImplementedError + + self.mask_generator = MaskGenerator( + input_size=config.DATA.IMG_SIZE, + mask_patch_size=config.DATA.MASK_PATCH_SIZE, + model_patch_size=model_patch_size, + mask_ratio=config.DATA.MASK_RATIO, + ) + + def __call__(self, img): + img = self.transform_img(img) + mask = self.mask_generator() + + return img, mask + +def collate_fn(batch): + # print(len(batch)) + # print('*'*10) + # print(batch[0][0]) + # print('#'*10) + # print(batch[0][1]) + # batch = list(filter(lambda x: x[0][0] is not None, batch)) + # if len(batch) == 0: return torch.Tensor() + + if not isinstance(batch[0][0], tuple): + return default_collate(batch) + else: + batch_num = len(batch) + ret = [] + for item_idx in range(len(batch[0][0])): + if batch[0][0][item_idx] is None: + ret.append(None) + else: + ret.append(default_collate([batch[i][0][item_idx] for i in range(batch_num)])) + ret.append(default_collate([batch[i][1] for i in range(batch_num)])) + return ret + + +def build_loader_simmim(config): + ############ single model ##################### + # transform = SimMIMTransform(config) + # dataset = ImageFolder(config.DATA.DATA_PATH, transform) + # sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True) + # dataloader = DataLoader(dataset, config.DATA.BATCH_SIZE, sampler=sampler, num_workers=config.DATA.NUM_WORKERS, pin_memory=True, drop_last=True, collate_fn=collate_fn) + + ############## multi model #################### + datasets = [] + ### 数据增强 ###### + 
model_paths = config.DATA.TYPE_PATH[0] + for i in model_paths.keys(): + a = config.DATA.SCALE[0][i].split(',') + scale_model = (float(a[0].split('(')[1]),float(a[1].split(')')[0])) + transform = SimMIMTransform(config, config.DATA.NORM[0][i], scale_model) + dataset = CachedImageFolder(model_paths[i], transform = transform, model = i) + datasets.append(dataset) + multi_task_train_dataset = MultiTaskDataset(datasets) + print(len(datasets)) + multi_task_batch_sampler = DistrubutedMultiTaskBatchSampler(datasets, batch_size=config.DATA.BATCH_SIZE, num_replicas=dist.get_world_size(), rank=dist.get_rank(), mix_opt=0, extra_task_ratio=0, drop_last=True,shuffle =True) + dataloader = DataLoader(multi_task_train_dataset, batch_sampler=multi_task_batch_sampler, num_workers=config.DATA.NUM_WORKERS, pin_memory=True, collate_fn=collate_fn) + # dataloader = DataLoader(multi_task_train_dataset, batch_sampler=multi_task_batch_sampler, pin_memory=True, collate_fn=collate_fn) + print(len(dataloader)) + + return dataloader \ No newline at end of file diff --git a/datasets/datasets.py b/datasets/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..30b8cd98a710186d7a2e6693554ee6719e3a502b --- /dev/null +++ b/datasets/datasets.py @@ -0,0 +1,21 @@ +import copy + + +datasets = {} + + +def register(name): + def decorator(cls): + datasets[name] = cls + return cls + return decorator + + +def make(dataset_spec, args=None): + if args is not None: + dataset_args = copy.deepcopy(dataset_spec['args']) + dataset_args.update(args) + else: + dataset_args = dataset_spec['args'] + dataset = datasets[dataset_spec['name']](**dataset_args) + return dataset diff --git a/datasets/image_folder.py b/datasets/image_folder.py new file mode 100644 index 0000000000000000000000000000000000000000..85734cb252816e1c6cbfce235e2beb11ed07082f --- /dev/null +++ b/datasets/image_folder.py @@ -0,0 +1,370 @@ +import os +import json +from PIL import Image + +import pickle +import imageio +import 
numpy as np +import torch +from torch.utils.data import Dataset +from torchvision import transforms +import random +from datasets import register + +import math +import torch.distributed as dist +from torch.utils.data import BatchSampler + +from torch.utils.data._utils.collate import default_collate + +@register('image-folder') +class ImageFolder(Dataset): + def __init__(self, path, split_file=None, split_key=None, first_k=None, size=None, + repeat=1, cache='none', mask=False): + self.repeat = repeat + self.cache = cache + self.path = path + self.Train = False + self.split_key = split_key + + self.size = size + self.mask = mask + if self.mask: + self.img_transform = transforms.Compose([ + transforms.Resize((self.size, self.size), interpolation=Image.NEAREST), + transforms.ToTensor(), + ]) + else: + self.img_transform = transforms.Compose([ + transforms.Resize((self.size, self.size)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + + if split_file is None: + filenames = sorted(os.listdir(path)) + else: + with open(split_file, 'r') as f: + filenames = json.load(f)[split_key] + if first_k is not None: + filenames = filenames[:first_k] + + self.files = [] + + for filename in filenames: + file = os.path.join(path, filename) + self.append_file(file) + + def append_file(self, file): + if self.cache == 'none': + self.files.append(file) + elif self.cache == 'in_memory': + self.files.append(self.img_process(file)) + + def __len__(self): + return len(self.files) * self.repeat + + def __getitem__(self, idx): + x = self.files[idx % len(self.files)] + + if self.cache == 'none': + return self.img_process(x) + elif self.cache == 'in_memory': + return x + + def img_process(self, file): + if self.mask: + # return Image.open(file).convert('L') + return file + else: + return Image.open(file).convert('RGB') + +@register('paired-image-folders') +class PairedImageFolders(Dataset): + + def __init__(self, root_path_1, 
root_path_2, **kwargs): + self.dataset_1 = ImageFolder(root_path_1, **kwargs) + self.dataset_2 = ImageFolder(root_path_2, **kwargs, mask=True) + + def __len__(self): + return len(self.dataset_1) + + def __getitem__(self, idx): + return self.dataset_1[idx], self.dataset_2[idx] + +class ImageFolder_multi_task(Dataset): + def __init__(self, path, split_file=None, split_key=None, first_k=None, size=None, + repeat=1, cache='none', mask=False): + self.repeat = repeat + self.cache = cache + self.path = path + self.Train = False + self.split_key = split_key + + self.size = size + self.mask = mask + if self.mask: + self.img_transform = transforms.Compose([ + transforms.Resize((self.size, self.size), interpolation=Image.NEAREST), + transforms.ToTensor(), + ]) + else: + self.img_transform = transforms.Compose([ + transforms.Resize((self.size, self.size)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + + if split_file is None: + filenames = sorted(os.listdir(path)) + else: + with open(split_file, 'r') as f: + filenames = json.load(f)[split_key] + if first_k is not None: + filenames = filenames[:first_k] + + self.files = [] + + for filename in filenames: + file = os.path.join(path, filename) + self.append_file(file) + + def append_file(self, file): + if self.cache == 'none': + self.files.append(file) + elif self.cache == 'in_memory': + self.files.append(self.img_process(file)) + + def __len__(self): + return len(self.files) * self.repeat + + def __getitem__(self, idx): + x = self.files[idx % len(self.files)] + + if self.cache == 'none': + return self.img_process(x) + elif self.cache == 'in_memory': + return x + + def img_process(self, file): + if self.mask: + # return Image.open(file).convert('L') + return file + else: + return Image.open(file).convert('RGB') + +@register('paired-image-folders-multi-task') +class PairedImageFolders_multi_task(Dataset): + + def __init__(self, root_path_1, root_path_2, model=None, 
**kwargs): + + self.dataset_1 = ImageFolder_multi_task(root_path_1, **kwargs) + self.dataset_2 = ImageFolder_multi_task(root_path_2, **kwargs, mask=True) + + def __len__(self): + return len(self.dataset_1) + + def __getitem__(self, idx): + return self.dataset_1[idx], self.dataset_2[idx] + + + + +# class MultiTaskDataset(Dataset): +# """ +# useage example: +# train_datasets = [SemData_Single(), SemData_Single()] +# multi_task_train_dataset = MultiTaskDataset(train_datasets) +# multi_task_batch_sampler = MultiTaskBatchSampler(train_datasets, batch_size=4, mix_opt=0, extra_task_ratio=0, drop_last=True) +# multi_task_train_data = DataLoader(multi_task_train_dataset, batch_sampler=multi_task_batch_sampler) +# for i, (task_id, input, target) in enumerate(multi_task_train_data): +# pre = model(input) +# """ +# def __init__(self, datasets_image, datasets_gt): +# self._datasets = datasets_image +# task_id_2_image_set_dic = {} +# for i, dataset in enumerate(datasets_image): +# task_id = i +# assert task_id not in task_id_2_image_set_dic, "Duplicate task_id %s" % task_id +# task_id_2_image_set_dic[task_id] = dataset +# self.datasets_1 = task_id_2_image_set_dic +# +# task_id_2_gt_set_dic = {} +# for i, dataset in enumerate(datasets_gt): +# task_id = i +# assert task_id not in task_id_2_gt_set_dic, "Duplicate task_id %s" % task_id +# task_id_2_gt_set_dic[task_id] = dataset +# self.dataset_2 = task_id_2_gt_set_dic +# +# +# def __len__(self): +# return sum(len(dataset) for dataset in self._datasets) +# +# def __getitem__(self, idx): +# task_id, sample_id = idx +# # return self._task_id_2_data_set_dic[task_id][sample_id] +# return self.dataset_1[task_id][sample_id], self.dataset_2[task_id][sample_id] + +class MultiTaskDataset(Dataset): + """ + useage example: + train_datasets = [SemData_Single(), SemData_Single()] + multi_task_train_dataset = MultiTaskDataset(train_datasets) + multi_task_batch_sampler = MultiTaskBatchSampler(train_datasets, batch_size=4, mix_opt=0, 
extra_task_ratio=0, drop_last=True) + multi_task_train_data = DataLoader(multi_task_train_dataset, batch_sampler=multi_task_batch_sampler) + for i, (task_id, input, target) in enumerate(multi_task_train_data): + pre = model(input) + """ + def __init__(self, datasets): + self._datasets = datasets + task_id_2_data_set_dic = {} + for i, dataset in enumerate(datasets): + task_id = i + assert task_id not in task_id_2_data_set_dic, "Duplicate task_id %s" % task_id + task_id_2_data_set_dic[task_id] = dataset + + self._task_id_2_data_set_dic = task_id_2_data_set_dic + + def __len__(self): + return sum(len(dataset) for dataset in self._datasets) + + def __getitem__(self, idx): + task_id, sample_id = idx + # print('----', idx, task_id, sample_id) + return self._task_id_2_data_set_dic[task_id][sample_id] + +def collate_fn(batch): + # print(len(batch)) + # print('*'*10) + # print(batch[0][0]) + # print('#'*10) + # print(batch[0][1]) + # batch = list(filter(lambda x: x[0][0] is not None, batch)) + # if len(batch) == 0: return torch.Tensor() + print('******------',batch) + if not isinstance(batch[0][0], tuple): + return default_collate(batch) + else: + batch_num = len(batch) + ret = [] + for item_idx in range(len(batch[0][0])): + if batch[0][0][item_idx] is None: + ret.append(None) + else: + ret.append(default_collate([batch[i][0][item_idx] for i in range(batch_num)])) + ret.append(default_collate([batch[i][1] for i in range(batch_num)])) + return ret + +class DistrubutedMultiTaskBatchSampler(BatchSampler): + """ + datasets: class the class of the Dataset + batch_size: int + mix_opt: int mix_opt ==0 shuffle all_task; mix_opt ==1 shuffle extra_task + extra_task_ratio(float, optional): the rate between task one and extra task + drop_last (bool, optional): set to ``True`` to drop the last incomplete batch, + if the dataset size is not divisible by the batch size. If ``False`` and + the size of dataset is not divisible by the batch size, then the last batch + will be smaller. 
(default: ``True``) + """ + + def __init__(self, datasets, batch_size, num_replicas, rank, mix_opt=0, extra_task_ratio=0, drop_last=True, + shuffle=True): + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + if rank >= num_replicas or rank < 0: + raise ValueError( + "Invalid rank {}, rank should be in the interval" + " [0, {}]".format(rank, num_replicas - 1)) + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + assert mix_opt in [0, 1], 'mix_opt must equal 0 or 1' + assert extra_task_ratio >= 0, 'extra_task_ratio must greater than 0' + # self._datasets = datasets + self._batch_size = batch_size + self._mix_opt = mix_opt + self._extra_task_ratio = extra_task_ratio + self._drop_last = drop_last + train_data_list = [] + self.shuffle = shuffle + for dataset in datasets: + print(len(dataset)) + train_data_list.append(self._get_index_batches(len(dataset), batch_size, self._drop_last)) + + ######### 一个列表里存n个dataset的数据,数据也以列表形式存在,一个dataset的列表里面把数据划分成了不同的batch的index + self._train_data_list = train_data_list + self.total_len = sum(len(train_data) for train_data in self._train_data_list) + + ######### DDP ###################### + if self._drop_last and self.total_len % self.num_replicas != 0: # type: ignore[arg-type] + self.num_samples = math.ceil( + (self.total_len - self.num_replicas) / self.num_replicas # type: ignore[arg-type] + ) + else: + self.num_samples = math.ceil(self.total_len / self.num_replicas) # type: ignore[arg-type] + + self.total_size = self.num_samples * self.num_replicas + self.epoch = 0 + self.seed = 0 + + def set_epoch(self, epoch): + # print('&&&&****') + self.epoch = epoch + + @staticmethod + def _get_index_batches(dataset_len, batch_size, drop_last): + # index_batches = 
[list(range(i, min(i+batch_size, dataset_len))) for i in range(0, dataset_len, batch_size)] + index = list(range(dataset_len)) + if drop_last and dataset_len % batch_size: + del index[-(dataset_len % batch_size):] + index_batches = [index[i:i + batch_size] for i in range(0, len(index), batch_size)] + return index_batches + + def __len__(self): + # return sum(len(train_data) for train_data in self._train_data_list) + return self.num_samples + + def __iter__(self): + all_iters = [iter(item) for item in self._train_data_list] + all_indices = self._gen_task_indices(self._train_data_list, self._mix_opt, self._extra_task_ratio) + + ######### DDP ###################### + random.shuffle(all_indices) + all_indices = all_indices[self.rank:self.total_size:self.num_replicas] + assert len(all_indices) == self.num_samples + + for local_task_idx in all_indices: + # task_id = self._datasets[local_task_idx].get_task_id() + batch = next(all_iters[local_task_idx]) + # batch = batch[self.rank:len(batch):self.num_replicas] + # print(local_task_idx) + yield [(local_task_idx, sample_id) for sample_id in batch] + # yield iter(batch) + + @staticmethod + def _gen_task_indices(train_data_list, mix_opt, extra_task_ratio): + + ########## accoding to the number of models ########### + all_indices = [] + for i in range(len(train_data_list)): + all_indices += [i] * len(train_data_list[i]) + # print(all_indices) + return all_indices + # def set_epoch(self, epoch) + # r""" + # Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas + # use a different random ordering for each epoch. Otherwise, the next iteration of this + # sampler will yield the same ordering. + + # Args: + # epoch (int): Epoch number. 
+ # """ + # self.epoch = epoch diff --git a/datasets/wrappers.py b/datasets/wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..a3eddc84839411dc6dad110f66f3729efbb284b5 --- /dev/null +++ b/datasets/wrappers.py @@ -0,0 +1,231 @@ + +import functools +import random +import math +from PIL import Image +import cv2 + +import numpy as np +import torch +from torch.utils.data import Dataset +from torchvision import transforms +import torchvision + +from datasets import register +import cv2 +from math import pi +from torchvision.transforms import InterpolationMode + +import torch.nn.functional as F +def to_mask(mask): + return transforms.ToTensor()( + transforms.Grayscale(num_output_channels=1)( + transforms.ToPILImage()(mask))) + + +def resize_fn(img, size): + return transforms.ToTensor()( + transforms.Resize(size)( + transforms.ToPILImage()(img))) + + +@register('val') +class ValDataset(Dataset): + def __init__(self, dataset, inp_size=None, augment=False): + self.dataset = dataset + self.inp_size = inp_size + self.augment = augment + + self.img_transform = transforms.Compose([ + # transforms.Resize((inp_size, inp_size)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + self.mask_transform = transforms.Compose([ + transforms.Resize((inp_size, inp_size), interpolation=Image.NEAREST), + transforms.ToTensor(), + ]) + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx): + img, mask = self.dataset[idx] + mask_name = mask + a = self.img_transform(img) + # b = self.mask_transform(mask) + + # print(idx, mask.filename) + # b = cv2.imread(mask.filename,cv2.IMREAD_UNCHANGED) + b = cv2.imread(mask,cv2.IMREAD_UNCHANGED) + return { + 'inp': self.img_transform(img), + 'gt': torch.tensor(b), + 'name': mask_name, + 'filp': False + # 'idx': idx + } + + +@register('train') +class TrainDataset(Dataset): + def __init__(self, dataset, size_min=None, size_max=None, inp_size=None, 
+ augment=False, gt_resize=None): + self.dataset = dataset + self.size_min = size_min + if size_max is None: + size_max = size_min + self.size_max = size_max + self.augment = augment + self.gt_resize = gt_resize + + self.inp_size = inp_size + self.img_transform = transforms.Compose([ + transforms.Resize((self.inp_size, self.inp_size)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + self.inverse_transform = transforms.Compose([ + transforms.Normalize(mean=[0., 0., 0.], + std=[1/0.229, 1/0.224, 1/0.225]), + transforms.Normalize(mean=[-0.485, -0.456, -0.406], + std=[1, 1, 1]) + ]) + self.mask_transform = transforms.Compose([ + transforms.Resize((self.inp_size, self.inp_size)), + transforms.ToTensor(), + ]) + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx): + # print('lodd****',idx,self.dataset[idx]) + img, mask = self.dataset[idx] + mask_name = mask + # print('befor mask', mask) + #new add + # print(idx, mask.filename, img.size) + + # mask = cv2.imread(mask.filename, cv2.IMREAD_UNCHANGED) + mask = cv2.imread(mask, cv2.IMREAD_UNCHANGED) + # print('befor mask', mask) + # random filp + if random.random() < 0.5: + img = img.transpose(Image.FLIP_LEFT_RIGHT) + # mask = mask.transpose(Image.FLIP_LEFT_RIGHT) + mask = cv2.flip(mask, 1) + + img = transforms.Resize((self.inp_size, self.inp_size))(img) + # mask = transforms.Resize((self.inp_size, self.inp_size), interpolation=InterpolationMode.NEAREST)(mask) + mask = torch.from_numpy(mask) + # print('behind mask', mask) + return { + 'inp': self.img_transform(img), + # 'gt': self.mask_transform(mask) + 'gt': mask, + 'name': mask_name, + # 'idx': idx + } + +@register('train_multi_task') +class TrainDataset(Dataset): + def __init__(self, dataset, size_min=None, size_max=None, inp_size=None, + augment=False, gt_resize=None): + self.dataset = dataset + self.size_min = size_min + if size_max is None: + size_max = size_min + self.size_max = 
size_max + self.augment = augment + self.gt_resize = gt_resize + + self.inp_size = inp_size + self.img_transform = transforms.Compose([ + transforms.Resize((self.inp_size, self.inp_size)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + self.inverse_transform = transforms.Compose([ + transforms.Normalize(mean=[0., 0., 0.], + std=[1/0.229, 1/0.224, 1/0.225]), + transforms.Normalize(mean=[-0.485, -0.456, -0.406], + std=[1, 1, 1]) + ]) + self.mask_transform = transforms.Compose([ + transforms.Resize((self.inp_size, self.inp_size)), + transforms.ToTensor(), + ]) + + def __len__(self): + return len(self.dataset) + # return sum(len(dataset) for dataset in self.datasets) + + def __getitem__(self, idx): + # print('lodd****',idx,self.dataset[idx]) + # print('+++++',idx) + img, mask = self.dataset[idx] + # print('befor mask', mask) + #new add + # print('****',idx, mask) + mask_name = mask + mask = cv2.imread(mask, cv2.IMREAD_UNCHANGED) + + # print('****',mask) + # print('befor mask', mask) + # random filp + if random.random() < 0.5: + img = img.transpose(Image.FLIP_LEFT_RIGHT) + # mask = mask.transpose(Image.FLIP_LEFT_RIGHT) + mask = cv2.flip(mask, 1) + + img = transforms.Resize((self.inp_size, self.inp_size))(img) + # mask = transforms.Resize((self.inp_size, self.inp_size), interpolation=InterpolationMode.NEAREST)(mask) + mask = torch.from_numpy(mask) + # print('behind mask', mask) + return { + 'inp': self.img_transform(img), + # 'gt': self.mask_transform(mask) + 'gt': mask, + 'name': mask_name + } + + +@register('val_multi_task') +class ValDataset(Dataset): + def __init__(self, dataset, inp_size=None, augment=False): + self.dataset = dataset + self.inp_size = inp_size + self.augment = augment + + self.img_transform = transforms.Compose([ + transforms.Resize((inp_size, inp_size)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + self.mask_transform 
= transforms.Compose([ + transforms.Resize((inp_size, inp_size), interpolation=Image.NEAREST), + transforms.ToTensor(), + ]) + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx): + img, mask = self.dataset[idx] + a = self.img_transform(img) + # b = self.mask_transform(mask) + mask_name = mask + # print(idx, mask.filename) + # b = cv2.imread(mask.filename,cv2.IMREAD_UNCHANGED) + b = cv2.imread(mask, cv2.IMREAD_UNCHANGED) + return { + 'inp': self.img_transform(img), + 'gt': torch.tensor(b), + 'name': mask_name + # 'idx': idx + } \ No newline at end of file diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..85bc7ff2d60d6f0d6557f3c805486e24b752bfd3 --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,4 @@ +from .models import register, make +from . import sam +from . import sam_single + diff --git a/models/__pycache__/__init__.cpython-310.pyc b/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eaa5bfb7b7c287163b9f27c9174eef2df4448ab1 Binary files /dev/null and b/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/models/__pycache__/__init__.cpython-37.pyc b/models/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..543b3203bab808721226525ba0990e4b771af08b Binary files /dev/null and b/models/__pycache__/__init__.cpython-37.pyc differ diff --git a/models/__pycache__/iou_loss.cpython-37.pyc b/models/__pycache__/iou_loss.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0966b56f51b34aac7bbb2de21208db4b2ad65a6 Binary files /dev/null and b/models/__pycache__/iou_loss.cpython-37.pyc differ diff --git a/models/__pycache__/models.cpython-310.pyc b/models/__pycache__/models.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e037c0ae10f6076ef1a81c7fb7ac50ab775eb33 Binary files /dev/null and 
b/models/__pycache__/models.cpython-310.pyc differ diff --git a/models/__pycache__/models.cpython-37.pyc b/models/__pycache__/models.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b3708ffb4662d21d6097291fbdabe34af66f171 Binary files /dev/null and b/models/__pycache__/models.cpython-37.pyc differ diff --git a/models/__pycache__/sam.cpython-310.pyc b/models/__pycache__/sam.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab8a7e0ff466cab4fc3e11959095aeabb6158845 Binary files /dev/null and b/models/__pycache__/sam.cpython-310.pyc differ diff --git a/models/__pycache__/sam.cpython-37.pyc b/models/__pycache__/sam.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0d4cb226a4a9dbcdd277903ab04b88f2f62e034 Binary files /dev/null and b/models/__pycache__/sam.cpython-37.pyc differ diff --git a/models/__pycache__/sam_single.cpython-37.pyc b/models/__pycache__/sam_single.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0846931106ec2431ed9e4c2c27bfd3fda5a9498 Binary files /dev/null and b/models/__pycache__/sam_single.cpython-37.pyc differ diff --git a/models/__pycache__/utils_prompt.cpython-37.pyc b/models/__pycache__/utils_prompt.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d192cc249552dfc74be7818f01ae706ae83cd0b2 Binary files /dev/null and b/models/__pycache__/utils_prompt.cpython-37.pyc differ diff --git a/models/block.py b/models/block.py new file mode 100644 index 0000000000000000000000000000000000000000..c9136112e36197f93a542b2304647ebe27da038a --- /dev/null +++ b/models/block.py @@ -0,0 +1,128 @@ +from __future__ import print_function + +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F + + +class MergeAndConv(nn.Module): + + def __init__(self, ic, oc, inner=32): + super().__init__() + + self.conv1 = nn.Conv2d(ic, inner, kernel_size=3, stride=1, 
padding=1) + self.bn = nn.BatchNorm2d(inner) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(inner, oc, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + x = self.conv2(self.bn(self.relu(self.conv1(x)))) + x = torch.sigmoid(x) + return x + + +class SideClassifer(nn.Module): + def __init__(self, ic, n_class=1, M=2, kernel_size=1): + super().__init__() + + sides = [] + for i in range(M): + sides.append(nn.Conv2d(ic, n_class, kernel_size=kernel_size)) + + self.sides = nn.ModuleList(sides) + + def forward(self, x): + return [fn(x) for fn in self.sides] + + +class UpsampleSKConv(nn.Module): + """docstring for UpsampleSKConvPlus""" + + def __init__(self, ic, oc, reduce=4): + super(UpsampleSKConv, self).__init__() + + self.relu = nn.ReLU(inplace=True) + self.prev = nn.Conv2d(ic, ic // reduce, kernel_size=3, stride=1, padding=1) + self.bn = nn.BatchNorm2d(ic // reduce) + + self.next = nn.Conv2d(ic // reduce, oc, kernel_size=1, stride=1) + self.bn2 = nn.BatchNorm2d(oc) + + self.sk = SKSPP(ic // reduce, ic // reduce, M=4) + + def forward(self, x): + x = F.interpolate(x, scale_factor=2) + + x = self.bn(self.relu(self.prev(x))) + + x = self.sk(x) + + x = self.bn2(self.relu(self.next(x))) + + return x + + +class SKSPP(nn.Module): + def __init__(self, features, WH, M=2, G=1, r=16, stride=1, L=32): + """ Constructor + Args: + features: input channel dimensionality. + WH: input spatial dimensionality, used for GAP kernel size. + M: the number of branchs. + G: num of convolution groups. + r: the radio for compute d, the length of z. + stride: stride, default 1. + L: the minimum dim of the vector z in paper, default 32. 
+ """ + super(SKSPP, self).__init__() + d = max(int(features / r), L) + self.M = M # original + self.features = features + self.convs = nn.ModuleList([]) + + # 1,3,5,7 padding:[0,1,2,3] + for i in range(1, M): + self.convs.append(nn.Sequential( + nn.Conv2d(features, features, kernel_size=1 + i * 2, dilation=1 + i * 2, stride=stride, + padding=((1 + i * 2) * (i * 2) + 1) // 2, groups=G), + nn.BatchNorm2d(features), + nn.ReLU(inplace=False) + )) + # self.gap = nn.AvgPool2d(int(WH/stride)) + self.fc = nn.Linear(features, d) + self.fcs = nn.ModuleList([]) + for i in range(M): + self.fcs.append( + nn.Linear(d, features) + ) + self.softmax = nn.Softmax(dim=1) + + def forward(self, x): + + feas = torch.unsqueeze(x, dim=1) + + # F->conv1x1->conv3x3->conv5x5->conv7x7 + + for i, conv in enumerate(self.convs): + x = conv(x) + # if i == 0: + # feas = fea + # else: + feas = torch.cat([feas, torch.unsqueeze(x, dim=1)], dim=1) + + fea_U = torch.sum(feas, dim=1) + fea_s = fea_U.mean(-1).mean(-1) + fea_z = self.fc(fea_s) + + for i, fc in enumerate(self.fcs): + vector = fc(fea_z).unsqueeze_(dim=1) + if i == 0: + attention_vectors = vector + else: + attention_vectors = torch.cat([attention_vectors, vector], dim=1) + + attention_vectors = self.softmax(attention_vectors) + attention_vectors = attention_vectors.unsqueeze(-1).unsqueeze(-1) + fea_v = (feas * attention_vectors).sum(dim=1) + return fea_v \ No newline at end of file diff --git a/models/bn_helper.py b/models/bn_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..e89ebe51f2bf4ce04dcf46fcb3cdc7f8aef2eedd --- /dev/null +++ b/models/bn_helper.py @@ -0,0 +1,16 @@ +import torch +import functools + +if torch.__version__.startswith('0'): + from .sync_bn.inplace_abn.bn import InPlaceABNSync + BatchNorm2d = functools.partial(InPlaceABNSync, activation='none') + BatchNorm2d_class = InPlaceABNSync + relu_inplace = False +else: + BatchNorm2d_class = BatchNorm2d = torch.nn.SyncBatchNorm + relu_inplace = True + 
+import torch +BatchNorm2d = torch.nn.BatchNorm2d +BatchNorm2d_class = BatchNorm2d +relu_inplace = False \ No newline at end of file diff --git a/models/iou_loss.py b/models/iou_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..49051fb2964b109feadc4e0254c4a1213b7be266 --- /dev/null +++ b/models/iou_loss.py @@ -0,0 +1,21 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +################################################################### +# ########################## iou loss ############################# +################################################################### +class IOU(torch.nn.Module): + def __init__(self): + super(IOU, self).__init__() + + def _iou(self, pred, target): + pred = torch.sigmoid(pred) + inter = (pred * target).sum(dim=(2, 3)) + union = (pred + target).sum(dim=(2, 3)) - inter + iou = 1 - (inter / union) + + return iou.mean() + + def forward(self, pred, target): + return self._iou(pred, target) diff --git a/models/mmseg/__init__.py b/models/mmseg/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e233b9d6c8b956d60e43dd97dd21775b2621748d --- /dev/null +++ b/models/mmseg/__init__.py @@ -0,0 +1,33 @@ +import mmcv + +from .version import __version__, version_info + +# MMCV_MIN = '1.1.4' +# MMCV_MAX = '1.3.0' + +MMCV_MIN = '1.1.4' +MMCV_MAX = '1.7.0' + + +def digit_version(version_str): + digit_version = [] + for x in version_str.split('.'): + if x.isdigit(): + digit_version.append(int(x)) + elif x.find('rc') != -1: + patch_version = x.split('rc') + digit_version.append(int(patch_version[0]) - 1) + digit_version.append(int(patch_version[1])) + return digit_version + + +mmcv_min_version = digit_version(MMCV_MIN) +mmcv_max_version = digit_version(MMCV_MAX) +mmcv_version = digit_version(mmcv.__version__) + + +assert (mmcv_min_version <= mmcv_version <= mmcv_max_version), \ + f'MMCV=={mmcv.__version__} is used but incompatible. 
' \ + f'Please install mmcv>={mmcv_min_version}, <={mmcv_max_version}.' + +__all__ = ['__version__', 'version_info'] diff --git a/models/mmseg/__pycache__/__init__.cpython-310.pyc b/models/mmseg/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..252bb5752fbb97aaeac4743dc71566ec78e84aec Binary files /dev/null and b/models/mmseg/__pycache__/__init__.cpython-310.pyc differ diff --git a/models/mmseg/__pycache__/__init__.cpython-37.pyc b/models/mmseg/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9fc1090ca1738cc4ad7732259fd0f7fd59d3c3ba Binary files /dev/null and b/models/mmseg/__pycache__/__init__.cpython-37.pyc differ diff --git a/models/mmseg/__pycache__/version.cpython-310.pyc b/models/mmseg/__pycache__/version.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..719edd8c3c47efc0800c2a5eae91b3ad09176482 Binary files /dev/null and b/models/mmseg/__pycache__/version.cpython-310.pyc differ diff --git a/models/mmseg/__pycache__/version.cpython-37.pyc b/models/mmseg/__pycache__/version.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..779e2dcd2adff953ffbd94e4397ed4c84909f850 Binary files /dev/null and b/models/mmseg/__pycache__/version.cpython-37.pyc differ diff --git a/models/mmseg/apis/__init__.py b/models/mmseg/apis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..170724be38de42daf2bc1a1910e181d68818f165 --- /dev/null +++ b/models/mmseg/apis/__init__.py @@ -0,0 +1,9 @@ +from .inference import inference_segmentor, init_segmentor, show_result_pyplot +from .test import multi_gpu_test, single_gpu_test +from .train import get_root_logger, set_random_seed, train_segmentor + +__all__ = [ + 'get_root_logger', 'set_random_seed', 'train_segmentor', 'init_segmentor', + 'inference_segmentor', 'multi_gpu_test', 'single_gpu_test', + 'show_result_pyplot' +] diff --git 
a/models/mmseg/apis/inference.py b/models/mmseg/apis/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..20c20dccda2837fb54e8ceaa2aa7e74404c085f7 --- /dev/null +++ b/models/mmseg/apis/inference.py @@ -0,0 +1,118 @@ +import matplotlib.pyplot as plt +import mmcv +import torch +from mmcv.parallel import collate, scatter +from mmcv.runner import load_checkpoint + +from mmseg.datasets.pipelines import Compose +from mmseg.models import build_segmentor + + +def init_segmentor(config, checkpoint=None, device='cuda:0'): + """Initialize a segmentor from config file. + + Args: + config (str or :obj:`mmcv.Config`): Config file path or the config + object. + checkpoint (str, optional): Checkpoint path. If left as None, the model + will not load any weights. + device (str, optional) CPU/CUDA device option. Default 'cuda:0'. + Use 'cpu' for loading model on CPU. + Returns: + nn.Module: The constructed segmentor. + """ + if isinstance(config, str): + config = mmcv.Config.fromfile(config) + elif not isinstance(config, mmcv.Config): + raise TypeError('config must be a filename or Config object, ' + 'but got {}'.format(type(config))) + config.model.pretrained = None + config.model.train_cfg = None + model = build_segmentor(config.model, test_cfg=config.get('test_cfg')) + if checkpoint is not None: + checkpoint = load_checkpoint(model, checkpoint, map_location='cpu') + model.CLASSES = checkpoint['meta']['CLASSES'] + model.PALETTE = checkpoint['meta']['PALETTE'] + model.cfg = config # save the config in the model for convenience + model.to(device) + model.eval() + return model + + +class LoadImage: + """A simple pipeline to load image.""" + + def __call__(self, results): + """Call function to load images into results. + + Args: + results (dict): A result dict contains the file name + of the image to be read. + + Returns: + dict: ``results`` will be returned containing loaded image. 
+ """ + + if isinstance(results['img'], str): + results['filename'] = results['img'] + results['ori_filename'] = results['img'] + else: + results['filename'] = None + results['ori_filename'] = None + img = mmcv.imread(results['img']) + results['img'] = img + results['img_shape'] = img.shape + results['ori_shape'] = img.shape + return results + + +def inference_segmentor(model, img): + """Inference image(s) with the segmentor. + + Args: + model (nn.Module): The loaded segmentor. + imgs (str/ndarray or list[str/ndarray]): Either image files or loaded + images. + + Returns: + (list[Tensor]): The segmentation result. + """ + cfg = model.cfg + device = next(model.parameters()).device # model device + # build the data pipeline + test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:] + test_pipeline = Compose(test_pipeline) + # prepare data + data = dict(img=img) + data = test_pipeline(data) + data = collate([data], samples_per_gpu=1) + if next(model.parameters()).is_cuda: + # scatter to specified GPU + data = scatter(data, [device])[0] + else: + data['img_metas'] = [i.data[0] for i in data['img_metas']] + + # forward the model + with torch.no_grad(): + result = model(return_loss=False, rescale=True, **data) + return result + + +def show_result_pyplot(model, img, result, palette=None, fig_size=(15, 10)): + """Visualize the segmentation results on the image. + + Args: + model (nn.Module): The loaded segmentor. + img (str or np.ndarray): Image filename or loaded image. + result (list): The segmentation result. + palette (list[list[int]]] | None): The palette of segmentation + map. If None is given, random palette will be generated. + Default: None + fig_size (tuple): Figure size of the pyplot figure. 
+ """ + if hasattr(model, 'module'): + model = model.module + img = model.show_result(img, result, palette=palette, show=False) + plt.figure(figsize=fig_size) + plt.imshow(mmcv.bgr2rgb(img)) + plt.show() diff --git a/models/mmseg/apis/test.py b/models/mmseg/apis/test.py new file mode 100644 index 0000000000000000000000000000000000000000..837f8d5ae51b9b8ea78af5a6921840bf176e0108 --- /dev/null +++ b/models/mmseg/apis/test.py @@ -0,0 +1,235 @@ +import os.path as osp +import pickle +import shutil +import tempfile + +import mmcv +import numpy as np +import torch +import torch.distributed as dist +from mmcv.image import tensor2imgs +from mmcv.runner import get_dist_info +from IPython import embed +from mmseg.ops import resize + +def np2tmp(array, temp_file_name=None): + """Save ndarray to local numpy file. + + Args: + array (ndarray): Ndarray to save. + temp_file_name (str): Numpy file name. If 'temp_file_name=None', this + function will generate a file name with tempfile.NamedTemporaryFile + to save ndarray. Default: None. + + Returns: + str: The numpy file name. + """ + + if temp_file_name is None: + temp_file_name = tempfile.NamedTemporaryFile( + suffix='.npy', delete=False).name + np.save(temp_file_name, array) + return temp_file_name + + +def single_gpu_test(model, + data_loader, + show=False, + out_dir=None, + efficient_test=False): + """Test with single GPU. + + Args: + model (nn.Module): Model to be tested. + data_loader (utils.data.Dataloader): Pytorch data loader. + show (bool): Whether show results during infernece. Default: False. + out_dir (str, optional): If specified, the results will be dumped into + the directory to save output results. + efficient_test (bool): Whether save the results as local numpy files to + save CPU memory during evaluation. Default: False. + + Returns: + list: The prediction results. 
+ """ + + model.eval() + results = [] + dataset = data_loader.dataset + prog_bar = mmcv.ProgressBar(len(dataset)) + for i, data in enumerate(data_loader): + with torch.no_grad(): + result = model(return_loss=False, **data) + + if show or out_dir: + img_tensor = data['img'][0] + img_metas = data['img_metas'][0].data[0] + imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg']) + assert len(imgs) == len(img_metas) + + for img, img_meta in zip(imgs, img_metas): + h, w, _ = img_meta['img_shape'] + img_show = img[:h, :w, :] + + ori_h, ori_w = img_meta['ori_shape'][:-1] + img_show = mmcv.imresize(img_show, (ori_w, ori_h)) + + if out_dir: + out_file = osp.join(out_dir, img_meta['ori_filename']) + else: + out_file = None + + model.module.show_result( + img_show, + result, + palette=dataset.PALETTE, + show=show, + out_file=out_file) + + if isinstance(result, list): + if efficient_test: + result = [np2tmp(_) for _ in result] + results.extend(result) + else: + if efficient_test: + result = np2tmp(result) + results.append(result) + + batch_size = data['img'][0].size(0) + for _ in range(batch_size): + prog_bar.update() + return results + + +def multi_gpu_test(model, + data_loader, + tmpdir=None, + gpu_collect=False, + efficient_test=False): + """Test model with multiple gpus. + + This method tests model with multiple gpus and collects the results + under two different modes: gpu and cpu modes. By setting 'gpu_collect=True' + it encodes results to gpu tensors and use gpu communication for results + collection. On cpu mode it saves the results on different gpus to 'tmpdir' + and collects them by the rank 0 worker. + + Args: + model (nn.Module): Model to be tested. + data_loader (utils.data.Dataloader): Pytorch data loader. + tmpdir (str): Path of directory to save the temporary results from + different gpus under cpu mode. + gpu_collect (bool): Option to use either gpu or cpu to collect results. 
+ efficient_test (bool): Whether save the results as local numpy files to + save CPU memory during evaluation. Default: False. + + Returns: + list: The prediction results. + """ + + model.eval() + results = [] + dataset = data_loader.dataset + rank, world_size = get_dist_info() + if rank == 0: + prog_bar = mmcv.ProgressBar(len(dataset)) + for i, data in enumerate(data_loader): + with torch.no_grad(): + result = model(return_loss=False, rescale=True, **data) + + if isinstance(result, list): + if efficient_test: + result = [np2tmp(_) for _ in result] + results.extend(result) + else: + if efficient_test: + result = np2tmp(result) + results.append(result) + + if rank == 0: + batch_size = data['img'][0].size(0) + for _ in range(batch_size * world_size): + prog_bar.update() + + # collect results from all ranks + if gpu_collect: + results = collect_results_gpu(results, len(dataset)) + else: + results = collect_results_cpu(results, len(dataset), tmpdir) + return results + + +def collect_results_cpu(result_part, size, tmpdir=None): + """Collect results with CPU.""" + rank, world_size = get_dist_info() + # create a tmp dir if it is not specified + if tmpdir is None: + MAX_LEN = 512 + # 32 is whitespace + dir_tensor = torch.full((MAX_LEN, ), + 32, + dtype=torch.uint8, + device='cuda') + if rank == 0: + tmpdir = tempfile.mkdtemp() + tmpdir = torch.tensor( + bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda') + dir_tensor[:len(tmpdir)] = tmpdir + dist.broadcast(dir_tensor, 0) + tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip() + else: + mmcv.mkdir_or_exist(tmpdir) + # dump the part result to the dir + mmcv.dump(result_part, osp.join(tmpdir, 'part_{}.pkl'.format(rank))) + dist.barrier() + # collect all parts + if rank != 0: + return None + else: + # load results of all parts from tmp dir + part_list = [] + for i in range(world_size): + part_file = osp.join(tmpdir, 'part_{}.pkl'.format(i)) + part_list.append(mmcv.load(part_file)) + # sort the results + 
def collect_results_gpu(result_part, size):
    """Collect results with GPU.

    Each rank pickles its partial results into a CUDA byte tensor, pads to
    the longest length across ranks, and all-gathers the buffers. Rank 0
    deserializes and interleaves the parts back into dataset order.

    Args:
        result_part (list): Partial results on this rank.
        size (int): Total dataset length (used to drop padded samples).

    Returns:
        list | None: Ordered results on rank 0, ``None`` on other ranks.
    """
    rank, world_size = get_dist_info()
    # serialize this rank's share into a byte tensor on the GPU
    encoded = torch.tensor(
        bytearray(pickle.dumps(result_part)), dtype=torch.uint8,
        device='cuda')
    # exchange byte lengths so every rank can pad to a common size
    local_shape = torch.tensor(encoded.shape, device='cuda')
    shape_list = [local_shape.clone() for _ in range(world_size)]
    dist.all_gather(shape_list, local_shape)
    shape_max = torch.tensor(shape_list).max()
    padded = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
    padded[:local_shape[0]] = encoded
    recv_buffers = [encoded.new_zeros(shape_max) for _ in range(world_size)]
    # gather every rank's (padded) pickled payload
    dist.all_gather(recv_buffers, padded)

    if rank != 0:
        return None

    parts = [
        pickle.loads(buf[:shp[0]].cpu().numpy().tobytes())
        for buf, shp in zip(recv_buffers, shape_list)
    ]
    # interleave: sample i of each rank, in rank order
    ordered = []
    for grouped in zip(*parts):
        ordered.extend(grouped)
    # the dataloader may pad some samples at the end
    return ordered[:size]
def set_random_seed(seed, deterministic=False):
    """Seed Python, NumPy and torch RNGs for reproducible runs.

    Args:
        seed (int): Seed to be used.
        deterministic (bool): Whether to set the deterministic option for
            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
            to True and `torch.backends.cudnn.benchmark` to False.
            Default: False.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed_all):
        seeder(seed)
    if deterministic:
        # trade kernel-selection speed for bitwise-reproducible results
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
batch_processor=None, + optimizer=optimizer, + work_dir=cfg.work_dir, + logger=logger, + meta=meta)) + + # register hooks + runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config, + cfg.checkpoint_config, cfg.log_config, + cfg.get('momentum_config', None)) + + # an ugly walkaround to make the .log and .log.json filenames the same + runner.timestamp = timestamp + + # register eval hooks + if validate: + val_dataset = build_dataset(cfg.data.val, dict(test_mode=True)) + val_dataloader = build_dataloader( + val_dataset, + samples_per_gpu=1, + workers_per_gpu=cfg.data.workers_per_gpu, + dist=distributed, + shuffle=False) + eval_cfg = cfg.get('evaluation', {}) + eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner' + eval_hook = DistEvalHook if distributed else EvalHook + runner.register_hook(eval_hook(val_dataloader, **eval_cfg)) + + if cfg.resume_from: + runner.resume(cfg.resume_from) + elif cfg.load_from: + runner.load_checkpoint(cfg.load_from) + runner.run(data_loaders, cfg.workflow) \ No newline at end of file diff --git a/models/mmseg/core/__init__.py b/models/mmseg/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..965605587211b7bf0bd6bc3acdbb33dd49cab023 --- /dev/null +++ b/models/mmseg/core/__init__.py @@ -0,0 +1,3 @@ +from .evaluation import * # noqa: F401, F403 +from .seg import * # noqa: F401, F403 +from .utils import * # noqa: F401, F403 diff --git a/models/mmseg/core/evaluation/__init__.py b/models/mmseg/core/evaluation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c58d926f06b3225f0ac93e10a3766ece6b6a8b2a --- /dev/null +++ b/models/mmseg/core/evaluation/__init__.py @@ -0,0 +1,8 @@ +from .class_names import get_classes, get_palette +from .eval_hooks import DistEvalHook, EvalHook +from .metrics import eval_metrics, mean_dice, mean_iou + +__all__ = [ + 'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'eval_metrics', + 'get_classes', 'get_palette' +] diff --git 
def cityscapes_classes():
    """Cityscapes class names for external use.

    Returns:
        list[str]: The 19 evaluation classes in official label-id order.
    """
    names = ('road,sidewalk,building,wall,fence,pole,traffic light,'
             'traffic sign,vegetation,terrain,sky,person,rider,car,truck,'
             'bus,train,motorcycle,bicycle')
    return names.split(',')
def voc_classes():
    """Pascal VOC class names for external use.

    Returns:
        list[str]: 21 names (background + 20 object classes) in label order.
    """
    names = ('background,aeroplane,bicycle,bird,boat,bottle,bus,car,cat,'
             'chair,cow,diningtable,dog,horse,motorbike,person,pottedplant,'
             'sheep,sofa,train,tvmonitor')
    return names.split(',')
def voc_palette():
    """Pascal VOC palette (RGB) for external use.

    Generated with the standard VOC bit-interleaved colormap: bit k of the
    label's 3k-th/3k+1-th/3k+2-th bits is placed into bit (7 - k) of the
    R/G/B channel respectively.

    Returns:
        list[list[int]]: 21 RGB triples in label order.
    """
    palette = []
    for label in range(21):
        r = g = b = 0
        remaining = label
        for shift in range(8):
            r |= ((remaining >> 0) & 1) << (7 - shift)
            g |= ((remaining >> 1) & 1) << (7 - shift)
            b |= ((remaining >> 2) & 1) << (7 - shift)
            remaining >>= 3
        palette.append([r, g, b])
    return palette
def get_palette(dataset):
    """Get class palette (RGB) of a dataset.

    Args:
        dataset (str): Dataset name or one of its aliases (see
            ``dataset_aliases``).

    Returns:
        list[list[int]]: Per-class RGB colors.

    Raises:
        TypeError: If ``dataset`` is not a string.
        ValueError: If ``dataset`` is not a known name or alias.
    """
    if not isinstance(dataset, str):
        raise TypeError(f'dataset must a str, but got {type(dataset)}')
    # invert the alias table: alias -> canonical dataset name
    alias2name = {
        alias: name
        for name, aliases in dataset_aliases.items()
        for alias in aliases
    }
    if dataset not in alias2name:
        raise ValueError(f'Unrecognized dataset: {dataset}')
    # explicit dispatch instead of building a call string for ``eval``
    palette_funcs = {
        'cityscapes': cityscapes_palette,
        'ade': ade_palette,
        'voc': voc_palette,
    }
    return palette_funcs[alias2name[dataset]]()
+ """ + + def __init__(self, dataloader, interval=1, by_epoch=False, **eval_kwargs): + if not isinstance(dataloader, DataLoader): + raise TypeError('dataloader must be a pytorch DataLoader, but got ' + f'{type(dataloader)}') + self.dataloader = dataloader + self.interval = interval + self.by_epoch = by_epoch + self.eval_kwargs = eval_kwargs + + def after_train_iter(self, runner): + """After train epoch hook.""" + if self.by_epoch or not self.every_n_iters(runner, self.interval): + return + from mmseg.apis import single_gpu_test + runner.log_buffer.clear() + results = single_gpu_test(runner.model, self.dataloader, show=False) + self.evaluate(runner, results) + + def after_train_epoch(self, runner): + """After train epoch hook.""" + if not self.by_epoch or not self.every_n_epochs(runner, self.interval): + return + from mmseg.apis import single_gpu_test + runner.log_buffer.clear() + results = single_gpu_test(runner.model, self.dataloader, show=False) + self.evaluate(runner, results) + + def evaluate(self, runner, results): + """Call evaluate function of dataset.""" + eval_res = self.dataloader.dataset.evaluate( + results, logger=runner.logger, **self.eval_kwargs) + for name, val in eval_res.items(): + runner.log_buffer.output[name] = val + runner.log_buffer.ready = True + + +class DistEvalHook(EvalHook): + """Distributed evaluation hook. + + Attributes: + dataloader (DataLoader): A PyTorch dataloader. + interval (int): Evaluation interval (by epochs). Default: 1. + tmpdir (str | None): Temporary directory to save the results of all + processes. Default: None. + gpu_collect (bool): Whether to use gpu or cpu to collect results. + Default: False. 
+ """ + + def __init__(self, + dataloader, + interval=1, + gpu_collect=False, + by_epoch=False, + **eval_kwargs): + if not isinstance(dataloader, DataLoader): + raise TypeError( + 'dataloader must be a pytorch DataLoader, but got {}'.format( + type(dataloader))) + self.dataloader = dataloader + self.interval = interval + self.gpu_collect = gpu_collect + self.by_epoch = by_epoch + self.eval_kwargs = eval_kwargs + + def after_train_iter(self, runner): + """After train epoch hook.""" + if self.by_epoch or not self.every_n_iters(runner, self.interval): + return + from mmseg.apis import multi_gpu_test + runner.log_buffer.clear() + results = multi_gpu_test( + runner.model, + self.dataloader, + tmpdir=osp.join(runner.work_dir, '.eval_hook'), + gpu_collect=self.gpu_collect) + if runner.rank == 0: + print('\n') + self.evaluate(runner, results) + + def after_train_epoch(self, runner): + """After train epoch hook.""" + if not self.by_epoch or not self.every_n_epochs(runner, self.interval): + return + from mmseg.apis import multi_gpu_test + runner.log_buffer.clear() + results = multi_gpu_test( + runner.model, + self.dataloader, + tmpdir=osp.join(runner.work_dir, '.eval_hook'), + gpu_collect=self.gpu_collect) + if runner.rank == 0: + print('\n') + self.evaluate(runner, results) diff --git a/models/mmseg/core/evaluation/metrics.py b/models/mmseg/core/evaluation/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..95b096e7a3d0034c09b9116f7d5614bb4e2601b1 --- /dev/null +++ b/models/mmseg/core/evaluation/metrics.py @@ -0,0 +1,229 @@ +import mmcv +import numpy as np + + +def intersect_and_union(pred_label, + label, + num_classes, + ignore_index, + label_map=dict(), + reduce_zero_label=False): + """Calculate intersection and Union. + + Args: + pred_label (ndarray): Prediction segmentation map. + label (ndarray): Ground truth segmentation map. + num_classes (int): Number of categories. + ignore_index (int): Index that will be ignored in evaluation. 
+ label_map (dict): Mapping old labels to new labels. The parameter will + work only when label is str. Default: dict(). + reduce_zero_label (bool): Wether ignore zero label. The parameter will + work only when label is str. Default: False. + + Returns: + ndarray: The intersection of prediction and ground truth histogram + on all classes. + ndarray: The union of prediction and ground truth histogram on all + classes. + ndarray: The prediction histogram on all classes. + ndarray: The ground truth histogram on all classes. + """ + + if isinstance(pred_label, str): + pred_label = np.load(pred_label) + + if isinstance(label, str): + label = mmcv.imread(label, flag='unchanged', backend='pillow') + # modify if custom classes + if label_map is not None: + for old_id, new_id in label_map.items(): + label[label == old_id] = new_id + if reduce_zero_label: + # avoid using underflow conversion + label[label == 0] = 255 + label = label - 1 + label[label == 254] = 255 + + mask = (label != ignore_index) + pred_label = pred_label[mask] + label = label[mask] + + intersect = pred_label[pred_label == label] + area_intersect, _ = np.histogram( + intersect, bins=np.arange(num_classes + 1)) + area_pred_label, _ = np.histogram( + pred_label, bins=np.arange(num_classes + 1)) + area_label, _ = np.histogram(label, bins=np.arange(num_classes + 1)) + area_union = area_pred_label + area_label - area_intersect + + return area_intersect, area_union, area_pred_label, area_label + + +def total_intersect_and_union(results, + gt_seg_maps, + num_classes, + ignore_index, + label_map=dict(), + reduce_zero_label=False): + """Calculate Total Intersection and Union. + + Args: + results (list[ndarray]): List of prediction segmentation maps. + gt_seg_maps (list[ndarray]): list of ground truth segmentation maps. + num_classes (int): Number of categories. + ignore_index (int): Index that will be ignored in evaluation. + label_map (dict): Mapping old labels to new labels. Default: dict(). 
+ reduce_zero_label (bool): Wether ignore zero label. Default: False. + + Returns: + ndarray: The intersection of prediction and ground truth histogram + on all classes. + ndarray: The union of prediction and ground truth histogram on all + classes. + ndarray: The prediction histogram on all classes. + ndarray: The ground truth histogram on all classes. + """ + + num_imgs = len(results) + assert len(gt_seg_maps) == num_imgs + total_area_intersect = np.zeros((num_classes, ), dtype=np.float) + total_area_union = np.zeros((num_classes, ), dtype=np.float) + total_area_pred_label = np.zeros((num_classes, ), dtype=np.float) + total_area_label = np.zeros((num_classes, ), dtype=np.float) + for i in range(num_imgs): + area_intersect, area_union, area_pred_label, area_label = \ + intersect_and_union(results[i], gt_seg_maps[i], num_classes, + ignore_index, label_map, reduce_zero_label) + total_area_intersect += area_intersect + total_area_union += area_union + total_area_pred_label += area_pred_label + total_area_label += area_label + return total_area_intersect, total_area_union, \ + total_area_pred_label, total_area_label + + +def mean_iou(results, + gt_seg_maps, + num_classes, + ignore_index, + nan_to_num=None, + label_map=dict(), + reduce_zero_label=False): + """Calculate Mean Intersection and Union (mIoU) + + Args: + results (list[ndarray]): List of prediction segmentation maps. + gt_seg_maps (list[ndarray]): list of ground truth segmentation maps. + num_classes (int): Number of categories. + ignore_index (int): Index that will be ignored in evaluation. + nan_to_num (int, optional): If specified, NaN values will be replaced + by the numbers defined by the user. Default: None. + label_map (dict): Mapping old labels to new labels. Default: dict(). + reduce_zero_label (bool): Wether ignore zero label. Default: False. + + Returns: + float: Overall accuracy on all images. + ndarray: Per category accuracy, shape (num_classes, ). 
def eval_metrics(results,
                 gt_seg_maps,
                 num_classes,
                 ignore_index,
                 metrics=['mIoU'],
                 nan_to_num=None,
                 label_map=dict(),
                 reduce_zero_label=False):
    """Calculate evaluation metrics.

    Args:
        results (list[ndarray]): List of prediction segmentation maps.
        gt_seg_maps (list[ndarray]): list of ground truth segmentation maps.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and
            'mDice'.
        nan_to_num (int, optional): If specified, NaN values will be
            replaced by the numbers defined by the user. Default: None.
        label_map (dict): Mapping old labels to new labels. Default: dict().
        reduce_zero_label (bool): Whether ignore zero label. Default: False.

    Returns:
        float: Overall accuracy on all images.
        ndarray: Per category accuracy, shape (num_classes, ).
        ndarray: Per category evaluation metrics, shape (num_classes, ).
    """
    metric_names = [metrics] if isinstance(metrics, str) else metrics
    if not set(metric_names).issubset({'mIoU', 'mDice'}):
        raise KeyError('metrics {} is not supported'.format(metric_names))

    total_area_intersect, total_area_union, total_area_pred_label, \
        total_area_label = total_intersect_and_union(
            results, gt_seg_maps, num_classes, ignore_index, label_map,
            reduce_zero_label)

    ret_metrics = [
        total_area_intersect.sum() / total_area_label.sum(),  # overall acc
        total_area_intersect / total_area_label,  # per-class acc
    ]
    for name in metric_names:
        if name == 'mIoU':
            ret_metrics.append(total_area_intersect / total_area_union)
        else:  # 'mDice' (validated above)
            ret_metrics.append(2 * total_area_intersect /
                               (total_area_pred_label + total_area_label))
    if nan_to_num is not None:
        ret_metrics = [np.nan_to_num(m, nan=nan_to_num) for m in ret_metrics]
    return ret_metrics
class BasePixelSampler(metaclass=ABCMeta):
    """Base class of pixel sampler.

    Concrete samplers implement :meth:`sample`, which produces a per-pixel
    weight map from segmentation logits and labels.
    """

    def __init__(self, **kwargs):
        # nothing to configure at the base level; subclasses take over
        pass

    @abstractmethod
    def sample(self, seg_logit, seg_label):
        """Placeholder for sample function."""
        pass
Example Mining Sampler for segmentation. + + Args: + context (nn.Module): The context of sampler, subclass of + :obj:`BaseDecodeHead`. + thresh (float, optional): The threshold for hard example selection. + Below which, are prediction with low confidence. If not + specified, the hard examples will be pixels of top ``min_kept`` + loss. Default: None. + min_kept (int, optional): The minimum number of predictions to keep. + Default: 100000. + """ + + def __init__(self, context, thresh=None, min_kept=100000): + super(OHEMPixelSampler, self).__init__() + self.context = context + assert min_kept > 1 + self.thresh = thresh + self.min_kept = min_kept + + def sample(self, seg_logit, seg_label): + """Sample pixels that have high loss or with low prediction confidence. + + Args: + seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W) + seg_label (torch.Tensor): segmentation label, shape (N, 1, H, W) + + Returns: + torch.Tensor: segmentation weight, shape (N, H, W) + """ + with torch.no_grad(): + assert seg_logit.shape[2:] == seg_label.shape[2:] + assert seg_label.shape[1] == 1 + seg_label = seg_label.squeeze(1).long() + batch_kept = self.min_kept * seg_label.size(0) + valid_mask = seg_label != self.context.ignore_index + seg_weight = seg_logit.new_zeros(size=seg_label.size()) + valid_seg_weight = seg_weight[valid_mask] + if self.thresh is not None: + seg_prob = F.softmax(seg_logit, dim=1) + + tmp_seg_label = seg_label.clone().unsqueeze(1) + tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0 + seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1) + sort_prob, sort_indices = seg_prob[valid_mask].sort() + + if sort_prob.numel() > 0: + min_threshold = sort_prob[min(batch_kept, + sort_prob.numel() - 1)] + else: + min_threshold = 0.0 + threshold = max(min_threshold, self.thresh) + valid_seg_weight[seg_prob[valid_mask] < threshold] = 1. 
def add_prefix(inputs, prefix):
    """Prepend ``prefix`` (dot-separated) to every key of a dict.

    Args:
        inputs (dict): The input dict with str keys.
        prefix (str): The prefix to add.

    Returns:
        dict: A new dict whose keys are ``f'{prefix}.{key}'`` and whose
        values are taken from ``inputs`` unchanged.
    """
    return {f'{prefix}.{key}': value for key, value in inputs.items()}
+ """ + CLASSES = ( + 'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ', + 'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth', + 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car', + 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug', + 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe', + 'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column', + 'signboard', 'chest of drawers', 'counter', 'sand', 'sink', + 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path', + 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door', + 'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table', + 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove', + 'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar', + 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower', + 'chandelier', 'awning', 'streetlight', 'booth', 'television receiver', + 'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister', + 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van', + 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything', + 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent', + 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank', + 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake', + 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce', + 'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen', + 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', + 'clock', 'flag') + + PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], + [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], + [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], + [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], + [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], + [0, 102, 200], [61, 230, 250], 
[255, 6, 51], [11, 102, 255], + [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], + [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], + [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], + [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], + [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], + [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], + [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], + [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], + [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], + [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255], + [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], + [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0], + [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], + [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255], + [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], + [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255], + [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], + [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255], + [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], + [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], + [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], + [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112], + [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160], + [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163], + [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], + [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0], + [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], + [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204], + [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], + [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], + [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], + [102, 255, 0], [92, 0, 255]] + + def __init__(self, **kwargs): + 
super(ADE20KDataset, self).__init__( + img_suffix='.jpg', + seg_map_suffix='.png', + reduce_zero_label=True, + **kwargs) diff --git a/models/mmseg/datasets/builder.py b/models/mmseg/datasets/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..f7a9926111cad3c8ab140ab8d289dbc66053517a --- /dev/null +++ b/models/mmseg/datasets/builder.py @@ -0,0 +1,169 @@ +import copy +import platform +import random +from functools import partial + +import numpy as np +from mmcv.parallel import collate +from mmcv.runner import get_dist_info +from mmcv.utils import Registry, build_from_cfg +from mmcv.utils.parrots_wrapper import DataLoader, PoolDataLoader +from torch.utils.data import DistributedSampler + +if platform.system() != 'Windows': + # https://github.com/pytorch/pytorch/issues/973 + import resource + rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) + hard_limit = rlimit[1] + soft_limit = min(4096, hard_limit) + resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit)) + +DATASETS = Registry('dataset') +PIPELINES = Registry('pipeline') + + +def _concat_dataset(cfg, default_args=None): + """Build :obj:`ConcatDataset by.""" + from .dataset_wrappers import ConcatDataset + img_dir = cfg['img_dir'] + ann_dir = cfg.get('ann_dir', None) + split = cfg.get('split', None) + num_img_dir = len(img_dir) if isinstance(img_dir, (list, tuple)) else 1 + if ann_dir is not None: + num_ann_dir = len(ann_dir) if isinstance(ann_dir, (list, tuple)) else 1 + else: + num_ann_dir = 0 + if split is not None: + num_split = len(split) if isinstance(split, (list, tuple)) else 1 + else: + num_split = 0 + if num_img_dir > 1: + assert num_img_dir == num_ann_dir or num_ann_dir == 0 + assert num_img_dir == num_split or num_split == 0 + else: + assert num_split == num_ann_dir or num_ann_dir <= 1 + num_dset = max(num_split, num_img_dir) + + datasets = [] + for i in range(num_dset): + data_cfg = copy.deepcopy(cfg) + if isinstance(img_dir, (list, tuple)): + 
data_cfg['img_dir'] = img_dir[i] + if isinstance(ann_dir, (list, tuple)): + data_cfg['ann_dir'] = ann_dir[i] + if isinstance(split, (list, tuple)): + data_cfg['split'] = split[i] + datasets.append(build_dataset(data_cfg, default_args)) + + return ConcatDataset(datasets) + + +def build_dataset(cfg, default_args=None): + """Build datasets.""" + from .dataset_wrappers import ConcatDataset, RepeatDataset + if isinstance(cfg, (list, tuple)): + dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg]) + elif cfg['type'] == 'RepeatDataset': + dataset = RepeatDataset( + build_dataset(cfg['dataset'], default_args), cfg['times']) + elif isinstance(cfg.get('img_dir'), (list, tuple)) or isinstance( + cfg.get('split', None), (list, tuple)): + dataset = _concat_dataset(cfg, default_args) + else: + dataset = build_from_cfg(cfg, DATASETS, default_args) + + return dataset + + +def build_dataloader(dataset, + samples_per_gpu, + workers_per_gpu, + num_gpus=1, + dist=True, + shuffle=True, + seed=None, + drop_last=False, + pin_memory=True, + dataloader_type='PoolDataLoader', + **kwargs): + """Build PyTorch DataLoader. + + In distributed training, each GPU/process has a dataloader. + In non-distributed training, there is only one dataloader for all GPUs. + + Args: + dataset (Dataset): A PyTorch dataset. + samples_per_gpu (int): Number of training samples on each GPU, i.e., + batch size of each GPU. + workers_per_gpu (int): How many subprocesses to use for data loading + for each GPU. + num_gpus (int): Number of GPUs. Only used in non-distributed training. + dist (bool): Distributed training/test or not. Default: True. + shuffle (bool): Whether to shuffle the data at every epoch. + Default: True. + seed (int | None): Seed to be used. Default: None. + drop_last (bool): Whether to drop the last incomplete batch in epoch. + Default: False + pin_memory (bool): Whether to use pin_memory in DataLoader. + Default: True + dataloader_type (str): Type of dataloader. 
Default: 'PoolDataLoader' + kwargs: any keyword argument to be used to initialize DataLoader + + Returns: + DataLoader: A PyTorch dataloader. + """ + rank, world_size = get_dist_info() + if dist: + sampler = DistributedSampler( + dataset, world_size, rank, shuffle=shuffle) + shuffle = False + batch_size = samples_per_gpu + num_workers = workers_per_gpu + else: + sampler = None + batch_size = num_gpus * samples_per_gpu + num_workers = num_gpus * workers_per_gpu + + init_fn = partial( + worker_init_fn, num_workers=num_workers, rank=rank, + seed=seed) if seed is not None else None + + assert dataloader_type in ( + 'DataLoader', + 'PoolDataLoader'), f'unsupported dataloader {dataloader_type}' + + if dataloader_type == 'PoolDataLoader': + dataloader = PoolDataLoader + elif dataloader_type == 'DataLoader': + dataloader = DataLoader + + data_loader = dataloader( + dataset, + batch_size=batch_size, + sampler=sampler, + num_workers=num_workers, + collate_fn=partial(collate, samples_per_gpu=samples_per_gpu), + pin_memory=pin_memory, + shuffle=shuffle, + worker_init_fn=init_fn, + drop_last=drop_last, + **kwargs) + + return data_loader + + +def worker_init_fn(worker_id, num_workers, rank, seed): + """Worker init func for dataloader. + + The seed of each worker equals to num_worker * rank + worker_id + user_seed + + Args: + worker_id (int): Worker id. + num_workers (int): Number of workers. + rank (int): The rank of current process. + seed (int): The random seed to use. 
+ """ + + worker_seed = num_workers * rank + worker_id + seed + np.random.seed(worker_seed) + random.seed(worker_seed) diff --git a/models/mmseg/datasets/chase_db1.py b/models/mmseg/datasets/chase_db1.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc29bea14704a4407f83474610cbc3bef32c708 --- /dev/null +++ b/models/mmseg/datasets/chase_db1.py @@ -0,0 +1,27 @@ +import os.path as osp + +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class ChaseDB1Dataset(CustomDataset): + """Chase_db1 dataset. + + In segmentation map annotation for Chase_db1, 0 stands for background, + which is included in 2 categories. ``reduce_zero_label`` is fixed to False. + The ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to + '_1stHO.png'. + """ + + CLASSES = ('background', 'vessel') + + PALETTE = [[120, 120, 120], [6, 230, 230]] + + def __init__(self, **kwargs): + super(ChaseDB1Dataset, self).__init__( + img_suffix='.png', + seg_map_suffix='_1stHO.png', + reduce_zero_label=False, + **kwargs) + assert osp.exists(self.img_dir) diff --git a/models/mmseg/datasets/cityscapes.py b/models/mmseg/datasets/cityscapes.py new file mode 100644 index 0000000000000000000000000000000000000000..fa9958ac1401644420d264c48cf8d807a44d7cf9 --- /dev/null +++ b/models/mmseg/datasets/cityscapes.py @@ -0,0 +1,217 @@ +import os.path as osp +import tempfile + +import mmcv +import numpy as np +from mmcv.utils import print_log +from PIL import Image + +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class CityscapesDataset(CustomDataset): + """Cityscapes dataset. + + The ``img_suffix`` is fixed to '_leftImg8bit.png' and ``seg_map_suffix`` is + fixed to '_gtFine_labelTrainIds.png' for Cityscapes dataset. 
+ """ + + CLASSES = ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole', + 'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky', + 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle', + 'bicycle') + + PALETTE = [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], + [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0], + [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60], + [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], + [0, 80, 100], [0, 0, 230], [119, 11, 32]] + + def __init__(self, **kwargs): + super(CityscapesDataset, self).__init__( + img_suffix='_leftImg8bit.png', + seg_map_suffix='_gtFine_labelTrainIds.png', + **kwargs) + + @staticmethod + def _convert_to_label_id(result): + """Convert trainId to id for cityscapes.""" + if isinstance(result, str): + result = np.load(result) + import cityscapesscripts.helpers.labels as CSLabels + result_copy = result.copy() + for trainId, label in CSLabels.trainId2label.items(): + result_copy[result == trainId] = label.id + + return result_copy + + def results2img(self, results, imgfile_prefix, to_label_id): + """Write the segmentation results to images. + + Args: + results (list[list | tuple | ndarray]): Testing results of the + dataset. + imgfile_prefix (str): The filename prefix of the png files. + If the prefix is "somepath/xxx", + the png files will be named "somepath/xxx.png". + to_label_id (bool): whether convert output to label_id for + submission + + Returns: + list[str: str]: result txt files which contains corresponding + semantic segmentation images. 
+ """ + mmcv.mkdir_or_exist(imgfile_prefix) + result_files = [] + prog_bar = mmcv.ProgressBar(len(self)) + for idx in range(len(self)): + result = results[idx] + if to_label_id: + result = self._convert_to_label_id(result) + filename = self.img_infos[idx]['filename'] + basename = osp.splitext(osp.basename(filename))[0] + + png_filename = osp.join(imgfile_prefix, f'{basename}.png') + + output = Image.fromarray(result.astype(np.uint8)).convert('P') + import cityscapesscripts.helpers.labels as CSLabels + palette = np.zeros((len(CSLabels.id2label), 3), dtype=np.uint8) + for label_id, label in CSLabels.id2label.items(): + palette[label_id] = label.color + + output.putpalette(palette) + output.save(png_filename) + result_files.append(png_filename) + prog_bar.update() + + return result_files + + def format_results(self, results, imgfile_prefix=None, to_label_id=True): + """Format the results into dir (standard format for Cityscapes + evaluation). + + Args: + results (list): Testing results of the dataset. + imgfile_prefix (str | None): The prefix of images files. It + includes the file path and the prefix of filename, e.g., + "a/b/prefix". If not specified, a temp file will be created. + Default: None. + to_label_id (bool): whether convert output to label_id for + submission. Default: False + + Returns: + tuple: (result_files, tmp_dir), result_files is a list containing + the image paths, tmp_dir is the temporal directory created + for saving json/png files when img_prefix is not specified. 
+ """ + + assert isinstance(results, list), 'results must be a list' + assert len(results) == len(self), ( + 'The length of results is not equal to the dataset len: ' + f'{len(results)} != {len(self)}') + + if imgfile_prefix is None: + tmp_dir = tempfile.TemporaryDirectory() + imgfile_prefix = tmp_dir.name + else: + tmp_dir = None + result_files = self.results2img(results, imgfile_prefix, to_label_id) + + return result_files, tmp_dir + + def evaluate(self, + results, + metric='mIoU', + logger=None, + imgfile_prefix=None, + efficient_test=False): + """Evaluation in Cityscapes/default protocol. + + Args: + results (list): Testing results of the dataset. + metric (str | list[str]): Metrics to be evaluated. + logger (logging.Logger | None | str): Logger used for printing + related information during evaluation. Default: None. + imgfile_prefix (str | None): The prefix of output image file, + for cityscapes evaluation only. It includes the file path and + the prefix of filename, e.g., "a/b/prefix". + If results are evaluated with cityscapes protocol, it would be + the prefix of output png files. The output files would be + png images under folder "a/b/prefix/xxx.png", where "xxx" is + the image name of cityscapes. If not specified, a temp file + will be created for evaluation. + Default: None. + + Returns: + dict[str, float]: Cityscapes/default metrics. + """ + + eval_results = dict() + metrics = metric.copy() if isinstance(metric, list) else [metric] + if 'cityscapes' in metrics: + eval_results.update( + self._evaluate_cityscapes(results, logger, imgfile_prefix)) + metrics.remove('cityscapes') + if len(metrics) > 0: + eval_results.update( + super(CityscapesDataset, + self).evaluate(results, metrics, logger, efficient_test)) + + return eval_results + + def _evaluate_cityscapes(self, results, logger, imgfile_prefix): + """Evaluation in Cityscapes protocol. + + Args: + results (list): Testing results of the dataset. 
+ logger (logging.Logger | str | None): Logger used for printing + related information during evaluation. Default: None. + imgfile_prefix (str | None): The prefix of output image file + + Returns: + dict[str: float]: Cityscapes evaluation results. + """ + try: + import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval # noqa + except ImportError: + raise ImportError('Please run "pip install cityscapesscripts" to ' + 'install cityscapesscripts first.') + msg = 'Evaluating in Cityscapes style' + if logger is None: + msg = '\n' + msg + print_log(msg, logger=logger) + + result_files, tmp_dir = self.format_results(results, imgfile_prefix) + + if tmp_dir is None: + result_dir = imgfile_prefix + else: + result_dir = tmp_dir.name + + eval_results = dict() + print_log(f'Evaluating results under {result_dir} ...', logger=logger) + + CSEval.args.evalInstLevelScore = True + CSEval.args.predictionPath = osp.abspath(result_dir) + CSEval.args.evalPixelAccuracy = True + CSEval.args.JSONOutput = False + + seg_map_list = [] + pred_list = [] + + # when evaluating with official cityscapesscripts, + # **_gtFine_labelIds.png is used + for seg_map in mmcv.scandir( + self.ann_dir, 'gtFine_labelIds.png', recursive=True): + seg_map_list.append(osp.join(self.ann_dir, seg_map)) + pred_list.append(CSEval.getPrediction(CSEval.args, seg_map)) + + eval_results.update( + CSEval.evaluateImgLists(pred_list, seg_map_list, CSEval.args)) + + if tmp_dir is not None: + tmp_dir.cleanup() + + return eval_results diff --git a/models/mmseg/datasets/cocostuff.py b/models/mmseg/datasets/cocostuff.py new file mode 100644 index 0000000000000000000000000000000000000000..b77e756476f7ea86829ae8d0bd972fd4535cdb36 --- /dev/null +++ b/models/mmseg/datasets/cocostuff.py @@ -0,0 +1,204 @@ +from .builder import DATASETS +from .custom import CustomDataset +from IPython import embed + +@DATASETS.register_module() +class CocoStuff(CustomDataset): + """Coco Stuff dataset. 
+ """ + nclass = 182 + CLASSES = [str(i) for i in range(nclass)] + + # random generated color + PALETTE = [ + [167, 200, 7], + [127, 228, 215], + [26, 135, 248], + [238, 73, 166], + [91, 210, 215], + [122, 20, 236], + [234, 173, 35], + [34, 98, 46], + [115, 11, 206], + [52, 251, 238], + [209, 156, 236], + [239, 10, 0], + [26, 122, 36], + [162, 181, 66], + [26, 64, 22], + [46, 226, 200], + [89, 176, 6], + [103, 36, 32], + [74, 89, 159], + [250, 215, 25], + [57, 246, 82], + [51, 156, 111], + [139, 114, 219], + [65, 208, 253], + [33, 184, 119], + [230, 239, 58], + [176, 141, 158], + [21, 29, 31], + [135, 133, 163], + [152, 241, 248], + [253, 54, 7], + [231, 86, 229], + [179, 220, 46], + [155, 217, 185], + [58, 251, 190], + [40, 201, 63], + [236, 52, 220], + [71, 203, 170], + [96, 56, 41], + [252, 231, 125], + [255, 60, 100], + [11, 172, 184], + [127, 46, 248], + [1, 105, 163], + [191, 218, 95], + [87, 160, 119], + [149, 223, 79], + [216, 180, 245], + [58, 226, 163], + [11, 43, 118], + [20, 23, 100], + [71, 222, 109], + [124, 197, 150], + [38, 106, 43], + [115, 73, 156], + [113, 110, 50], + [94, 2, 184], + [163, 168, 155], + [83, 39, 145], + [150, 169, 81], + [134, 25, 2], + [145, 49, 138], + [46, 27, 209], + [145, 187, 117], + [197, 9, 211], + [179, 12, 118], + [107, 241, 133], + [255, 176, 224], + [49, 56, 217], + [10, 227, 177], + [152, 117, 25], + [139, 76, 23], + [53, 191, 10], + [14, 244, 90], + [247, 94, 189], + [202, 160, 149], + [24, 31, 150], + [164, 236, 24], + [47, 10, 204], + [84, 187, 44], + [17, 153, 55], + [9, 191, 39], + [216, 53, 216], + [54, 13, 26], + [241, 13, 196], + [157, 90, 225], + [99, 195, 27], + [20, 186, 253], + [175, 192, 0], + [81, 11, 238], + [137, 83, 196], + [53, 186, 24], + [231, 20, 101], + [246, 223, 173], + [75, 202, 249], + [9, 188, 201], + [216, 83, 7], + [152, 92, 54], + [137, 192, 79], + [242, 169, 49], + [99, 65, 207], + [178, 112, 1], + [120, 135, 40], + [71, 220, 82], + [180, 83, 172], + [68, 137, 75], + [46, 58, 15], + [0, 
80, 68], + [175, 86, 173], + [19, 208, 152], + [215, 235, 142], + [95, 30, 166], + [246, 193, 8], + [222, 19, 72], + [177, 29, 183], + [238, 61, 178], + [246, 136, 87], + [199, 207, 174], + [218, 149, 231], + [98, 179, 168], + [23, 10, 10], + [223, 9, 253], + [206, 114, 95], + [177, 242, 152], + [115, 189, 142], + [254, 105, 107], + [59, 175, 153], + [42, 114, 178], + [50, 121, 91], + [78, 238, 175], + [232, 201, 123], + [61, 39, 248], + [76, 43, 218], + [121, 191, 38], + [13, 164, 242], + [83, 70, 160], + [109, 2, 64], + [252, 81, 105], + [151, 107, 83], + [31, 95, 170], + [7, 238, 218], + [227, 49, 19], + [56, 102, 49], + [152, 241, 48], + [110, 35, 108], + [59, 198, 242], + [186, 189, 39], + [26, 157, 41], + [183, 16, 169], + [114, 26, 104], + [131, 142, 127], + [118, 85, 219], + [203, 84, 210], + [245, 16, 127], + [57, 238, 110], + [223, 225, 154], + [143, 21, 231], + [12, 215, 113], + [117, 58, 3], + [170, 201, 252], + [60, 190, 197], + [38, 22, 24], + [37, 155, 237], + [175, 41, 211], + [188, 151, 129], + [231, 92, 102], + [229, 112, 245], + [157, 182, 40], + [1, 60, 204], + [57, 58, 19], + [156, 199, 180], + [211, 47, 8], + [153, 115, 233], + [172, 117, 198], + [33, 63, 208], + [107, 80, 154], + [217, 164, 13], + [136, 83, 59], + [53, 206, 6], + [95, 127, 75], + [110, 22, 240], + [244, 212, 2] + ] + + assert len(CLASSES) == len(PALETTE) + + def __init__(self, **kwargs): + super(CocoStuff, self).__init__( + img_suffix='.jpg', + seg_map_suffix='.png', + **kwargs) \ No newline at end of file diff --git a/models/mmseg/datasets/custom.py b/models/mmseg/datasets/custom.py new file mode 100644 index 0000000000000000000000000000000000000000..dc923fb42dfeb9ad04a6b19e5fe1370a93cf99b3 --- /dev/null +++ b/models/mmseg/datasets/custom.py @@ -0,0 +1,380 @@ +import os +import os.path as osp +from functools import reduce + +import mmcv +import numpy as np +from mmcv.utils import print_log +from terminaltables import AsciiTable +from torch.utils.data import Dataset + +from 
mmseg.core import eval_metrics +from mmseg.utils import get_root_logger +from .builder import DATASETS +from .pipelines import Compose + + +@DATASETS.register_module() +class CustomDataset(Dataset): + """Custom dataset for semantic segmentation. An example of file structure + is as followed. + + .. code-block:: none + + ├── data + │ ├── my_dataset + │ │ ├── img_dir + │ │ │ ├── train + │ │ │ │ ├── xxx{img_suffix} + │ │ │ │ ├── yyy{img_suffix} + │ │ │ │ ├── zzz{img_suffix} + │ │ │ ├── val + │ │ ├── ann_dir + │ │ │ ├── train + │ │ │ │ ├── xxx{seg_map_suffix} + │ │ │ │ ├── yyy{seg_map_suffix} + │ │ │ │ ├── zzz{seg_map_suffix} + │ │ │ ├── val + + The img/gt_semantic_seg pair of CustomDataset should be of the same + except suffix. A valid img/gt_semantic_seg filename pair should be like + ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (extension is also included + in the suffix). If split is given, then ``xxx`` is specified in txt file. + Otherwise, all files in ``img_dir/``and ``ann_dir`` will be loaded. + Please refer to ``docs/tutorials/new_dataset.md`` for more details. + + + Args: + pipeline (list[dict]): Processing pipeline + img_dir (str): Path to image directory + img_suffix (str): Suffix of images. Default: '.jpg' + ann_dir (str, optional): Path to annotation directory. Default: None + seg_map_suffix (str): Suffix of segmentation maps. Default: '.png' + split (str, optional): Split txt file. If split is specified, only + file with suffix in the splits will be loaded. Otherwise, all + images in img_dir/ann_dir will be loaded. Default: None + data_root (str, optional): Data root for img_dir/ann_dir. Default: + None. + test_mode (bool): If test_mode=True, gt wouldn't be loaded. + ignore_index (int): The label index to be ignored. Default: 255 + reduce_zero_label (bool): Whether to mark label zero as ignored. + Default: False + classes (str | Sequence[str], optional): Specify classes to load. + If is None, ``cls.CLASSES`` will be used. Default: None. 
+ palette (Sequence[Sequence[int]]] | np.ndarray | None): + The palette of segmentation map. If None is given, and + self.PALETTE is None, random palette will be generated. + Default: None + """ + + CLASSES = None + + PALETTE = None + + def __init__(self, + pipeline, + img_dir, + img_suffix='.jpg', + ann_dir=None, + seg_map_suffix='.png', + split=None, + data_root=None, + test_mode=False, + ignore_index=255, + reduce_zero_label=False, + classes=None, + palette=None): + self.pipeline = Compose(pipeline) + self.img_dir = img_dir + self.img_suffix = img_suffix + self.ann_dir = ann_dir + self.seg_map_suffix = seg_map_suffix + self.split = split + self.data_root = data_root + self.test_mode = test_mode + self.ignore_index = ignore_index + self.reduce_zero_label = reduce_zero_label + self.label_map = None + self.CLASSES, self.PALETTE = self.get_classes_and_palette( + classes, palette) + + # join paths if data_root is specified + if self.data_root is not None: + if not osp.isabs(self.img_dir): + self.img_dir = osp.join(self.data_root, self.img_dir) + if not (self.ann_dir is None or osp.isabs(self.ann_dir)): + self.ann_dir = osp.join(self.data_root, self.ann_dir) + if not (self.split is None or osp.isabs(self.split)): + self.split = osp.join(self.data_root, self.split) + + # load annotations + self.img_infos = self.load_annotations(self.img_dir, self.img_suffix, + self.ann_dir, + self.seg_map_suffix, self.split) + + def __len__(self): + """Total number of samples of data.""" + return len(self.img_infos) + + def load_annotations(self, img_dir, img_suffix, ann_dir, seg_map_suffix, + split): + """Load annotation from directory. + + Args: + img_dir (str): Path to image directory + img_suffix (str): Suffix of images. + ann_dir (str|None): Path to annotation directory. + seg_map_suffix (str|None): Suffix of segmentation maps. + split (str|None): Split txt file. If split is specified, only file + with suffix in the splits will be loaded. 
Otherwise, all images + in img_dir/ann_dir will be loaded. Default: None + + Returns: + list[dict]: All image info of dataset. + """ + + img_infos = [] + if split is not None: + with open(split) as f: + for line in f: + img_name = line.strip() + img_info = dict(filename=img_name + img_suffix) + if ann_dir is not None: + seg_map = img_name + seg_map_suffix + img_info['ann'] = dict(seg_map=seg_map) + img_infos.append(img_info) + else: + for img in mmcv.scandir(img_dir, img_suffix, recursive=True): + img_info = dict(filename=img) + if ann_dir is not None: + seg_map = img.replace(img_suffix, seg_map_suffix) + img_info['ann'] = dict(seg_map=seg_map) + img_infos.append(img_info) + + print_log(f'Loaded {len(img_infos)} images', logger=get_root_logger()) + return img_infos + + def get_ann_info(self, idx): + """Get annotation by index. + + Args: + idx (int): Index of data. + + Returns: + dict: Annotation info of specified index. + """ + + return self.img_infos[idx]['ann'] + + def pre_pipeline(self, results): + """Prepare results dict for pipeline.""" + results['seg_fields'] = [] + results['img_prefix'] = self.img_dir + results['seg_prefix'] = self.ann_dir + if self.custom_classes: + results['label_map'] = self.label_map + + def __getitem__(self, idx): + """Get training/test data after pipeline. + + Args: + idx (int): Index of data. + + Returns: + dict: Training/test data (with annotation if `test_mode` is set + False). + """ + + if self.test_mode: + return self.prepare_test_img(idx) + else: + return self.prepare_train_img(idx) + + def prepare_train_img(self, idx): + """Get training data and annotations after pipeline. + + Args: + idx (int): Index of data. + + Returns: + dict: Training data and annotation after pipeline with new keys + introduced by pipeline. 
+ """ + + img_info = self.img_infos[idx] + ann_info = self.get_ann_info(idx) + results = dict(img_info=img_info, ann_info=ann_info) + self.pre_pipeline(results) + return self.pipeline(results) + + def prepare_test_img(self, idx): + """Get testing data after pipeline. + + Args: + idx (int): Index of data. + + Returns: + dict: Testing data after pipeline with new keys intorduced by + piepline. + """ + + img_info = self.img_infos[idx] + results = dict(img_info=img_info) + self.pre_pipeline(results) + return self.pipeline(results) + + def format_results(self, results, **kwargs): + """Place holder to format result to dataset specific output.""" + pass + + def get_gt_seg_maps(self, efficient_test=False): + """Get ground truth segmentation maps for evaluation.""" + gt_seg_maps = [] + for img_info in self.img_infos: + seg_map = osp.join(self.ann_dir, img_info['ann']['seg_map']) + if efficient_test: + gt_seg_map = seg_map + else: + gt_seg_map = mmcv.imread( + seg_map, flag='unchanged', backend='pillow') + gt_seg_maps.append(gt_seg_map) + return gt_seg_maps + + def get_classes_and_palette(self, classes=None, palette=None): + """Get class names of current dataset. + + Args: + classes (Sequence[str] | str | None): If classes is None, use + default CLASSES defined by builtin dataset. If classes is a + string, take it as a file name. The file contains the name of + classes where each line contains one class name. If classes is + a tuple or list, override the CLASSES defined by the dataset. + palette (Sequence[Sequence[int]]] | np.ndarray | None): + The palette of segmentation map. If None is given, random + palette will be generated. 
Default: None + """ + if classes is None: + self.custom_classes = False + return self.CLASSES, self.PALETTE + + self.custom_classes = True + if isinstance(classes, str): + # take it as a file path + class_names = mmcv.list_from_file(classes) + elif isinstance(classes, (tuple, list)): + class_names = classes + else: + raise ValueError(f'Unsupported type {type(classes)} of classes.') + + if self.CLASSES: + if not set(classes).issubset(self.CLASSES): + raise ValueError('classes is not a subset of CLASSES.') + + # dictionary, its keys are the old label ids and its values + # are the new label ids. + # used for changing pixel labels in load_annotations. + self.label_map = {} + for i, c in enumerate(self.CLASSES): + if c not in class_names: + self.label_map[i] = -1 + else: + self.label_map[i] = classes.index(c) + + palette = self.get_palette_for_custom_classes(class_names, palette) + + return class_names, palette + + def get_palette_for_custom_classes(self, class_names, palette=None): + + if self.label_map is not None: + # return subset of palette + palette = [] + for old_id, new_id in sorted( + self.label_map.items(), key=lambda x: x[1]): + if new_id != -1: + palette.append(self.PALETTE[old_id]) + palette = type(self.PALETTE)(palette) + + elif palette is None: + if self.PALETTE is None: + palette = np.random.randint(0, 255, size=(len(class_names), 3)) + else: + palette = self.PALETTE + + return palette + + def evaluate(self, + results, + metric='mIoU', + logger=None, + efficient_test=False, + **kwargs): + """Evaluate the dataset. + + Args: + results (list): Testing results of the dataset. + metric (str | list[str]): Metrics to be evaluated. 'mIoU' and + 'mDice' are supported. + logger (logging.Logger | None | str): Logger used for printing + related information during evaluation. Default: None. + + Returns: + dict[str, float]: Default metrics. 
+ """ + + if isinstance(metric, str): + metric = [metric] + allowed_metrics = ['mIoU', 'mDice'] + if not set(metric).issubset(set(allowed_metrics)): + raise KeyError('metric {} is not supported'.format(metric)) + eval_results = {} + gt_seg_maps = self.get_gt_seg_maps(efficient_test) + if self.CLASSES is None: + num_classes = len( + reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps])) + else: + num_classes = len(self.CLASSES) + ret_metrics = eval_metrics( + results, + gt_seg_maps, + num_classes, + self.ignore_index, + metric, + label_map=self.label_map, + reduce_zero_label=self.reduce_zero_label) + class_table_data = [['Class'] + [m[1:] for m in metric] + ['Acc']] + if self.CLASSES is None: + class_names = tuple(range(num_classes)) + else: + class_names = self.CLASSES + ret_metrics_round = [ + np.round(ret_metric * 100, 2) for ret_metric in ret_metrics + ] + for i in range(num_classes): + class_table_data.append([class_names[i]] + + [m[i] for m in ret_metrics_round[2:]] + + [ret_metrics_round[1][i]]) + summary_table_data = [['Scope'] + + ['m' + head + for head in class_table_data[0][1:]] + ['aAcc']] + ret_metrics_mean = [ + np.round(np.nanmean(ret_metric) * 100, 2) + for ret_metric in ret_metrics + ] + summary_table_data.append(['global'] + ret_metrics_mean[2:] + + [ret_metrics_mean[1]] + + [ret_metrics_mean[0]]) + print_log('per class results:', logger) + table = AsciiTable(class_table_data) + print_log('\n' + table.table, logger=logger) + print_log('Summary:', logger) + table = AsciiTable(summary_table_data) + print_log('\n' + table.table, logger=logger) + + for i in range(1, len(summary_table_data[0])): + eval_results[summary_table_data[0] + [i]] = summary_table_data[1][i] / 100.0 + if mmcv.is_list_of(results, str): + for file_name in results: + os.remove(file_name) + return eval_results diff --git a/models/mmseg/datasets/dataset_wrappers.py b/models/mmseg/datasets/dataset_wrappers.py new file mode 100644 index 
0000000000000000000000000000000000000000..d6a5e957ec3b44465432617cf6e8f0b86a8a5efa --- /dev/null +++ b/models/mmseg/datasets/dataset_wrappers.py @@ -0,0 +1,50 @@ +from torch.utils.data.dataset import ConcatDataset as _ConcatDataset + +from .builder import DATASETS + + +@DATASETS.register_module() +class ConcatDataset(_ConcatDataset): + """A wrapper of concatenated dataset. + + Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but + concat the group flag for image aspect ratio. + + Args: + datasets (list[:obj:`Dataset`]): A list of datasets. + """ + + def __init__(self, datasets): + super(ConcatDataset, self).__init__(datasets) + self.CLASSES = datasets[0].CLASSES + self.PALETTE = datasets[0].PALETTE + + +@DATASETS.register_module() +class RepeatDataset(object): + """A wrapper of repeated dataset. + + The length of repeated dataset will be `times` larger than the original + dataset. This is useful when the data loading time is long but the dataset + is small. Using RepeatDataset can reduce the data loading time between + epochs. + + Args: + dataset (:obj:`Dataset`): The dataset to be repeated. + times (int): Repeat times. + """ + + def __init__(self, dataset, times): + self.dataset = dataset + self.times = times + self.CLASSES = dataset.CLASSES + self.PALETTE = dataset.PALETTE + self._ori_len = len(self.dataset) + + def __getitem__(self, idx): + """Get item from original dataset.""" + return self.dataset[idx % self._ori_len] + + def __len__(self): + """The length is multiplied by ``times``""" + return self.times * self._ori_len diff --git a/models/mmseg/datasets/drive.py b/models/mmseg/datasets/drive.py new file mode 100644 index 0000000000000000000000000000000000000000..3cbfda8ae74bdf26c5aef197ff2866a7c7ad0cfd --- /dev/null +++ b/models/mmseg/datasets/drive.py @@ -0,0 +1,27 @@ +import os.path as osp + +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class DRIVEDataset(CustomDataset): + """DRIVE dataset. 
+ + In segmentation map annotation for DRIVE, 0 stands for background, which is + included in 2 categories. ``reduce_zero_label`` is fixed to False. The + ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to + '_manual1.png'. + """ + + CLASSES = ('background', 'vessel') + + PALETTE = [[120, 120, 120], [6, 230, 230]] + + def __init__(self, **kwargs): + super(DRIVEDataset, self).__init__( + img_suffix='.png', + seg_map_suffix='_manual1.png', + reduce_zero_label=False, + **kwargs) + assert osp.exists(self.img_dir) diff --git a/models/mmseg/datasets/hrf.py b/models/mmseg/datasets/hrf.py new file mode 100644 index 0000000000000000000000000000000000000000..923203b51377f9344277fc561803d7a78bd2c684 --- /dev/null +++ b/models/mmseg/datasets/hrf.py @@ -0,0 +1,27 @@ +import os.path as osp + +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class HRFDataset(CustomDataset): + """HRF dataset. + + In segmentation map annotation for HRF, 0 stands for background, which is + included in 2 categories. ``reduce_zero_label`` is fixed to False. The + ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to + '.png'. + """ + + CLASSES = ('background', 'vessel') + + PALETTE = [[120, 120, 120], [6, 230, 230]] + + def __init__(self, **kwargs): + super(HRFDataset, self).__init__( + img_suffix='.png', + seg_map_suffix='.png', + reduce_zero_label=False, + **kwargs) + assert osp.exists(self.img_dir) diff --git a/models/mmseg/datasets/mapillary.py b/models/mmseg/datasets/mapillary.py new file mode 100644 index 0000000000000000000000000000000000000000..81d94918199deecb5a05548f7543ab12241a27b0 --- /dev/null +++ b/models/mmseg/datasets/mapillary.py @@ -0,0 +1,46 @@ +from .builder import DATASETS +from .custom import CustomDataset +from IPython import embed + +@DATASETS.register_module() +class MapillaryDataset(CustomDataset): + """Mapillary dataset. 
+ """ + CLASSES = ('Bird', 'Ground Animal', 'Curb', 'Fence', 'Guard Rail', 'Barrier', + 'Wall', 'Bike Lane', 'Crosswalk - Plain', 'Curb Cut', 'Parking', 'Pedestrian Area', + 'Rail Track', 'Road', 'Service Lane', 'Sidewalk', 'Bridge', 'Building', 'Tunnel', + 'Person', 'Bicyclist', 'Motorcyclist', 'Other Rider', 'Lane Marking - Crosswalk', + 'Lane Marking - General', 'Mountain', 'Sand', 'Sky', 'Snow', 'Terrain', 'Vegetation', + 'Water', 'Banner', 'Bench', 'Bike Rack', 'Billboard', 'Catch Basin', 'CCTV Camera', + 'Fire Hydrant', 'Junction Box', 'Mailbox', 'Manhole', 'Phone Booth', 'Pothole', + 'Street Light', 'Pole', 'Traffic Sign Frame', 'Utility Pole', 'Traffic Light', + 'Traffic Sign (Back)', 'Traffic Sign (Front)', 'Trash Can', 'Bicycle', 'Boat', + 'Bus', 'Car', 'Caravan', 'Motorcycle', 'On Rails', 'Other Vehicle', 'Trailer', + 'Truck', 'Wheeled Slow', 'Car Mount', 'Ego Vehicle', 'Unlabeled') + + PALETTE = [[165, 42, 42], [0, 192, 0], [196, 196, 196], [190, 153, 153], + [180, 165, 180], [90, 120, 150], [ + 102, 102, 156], [128, 64, 255], + [140, 140, 200], [170, 170, 170], [250, 170, 160], [96, 96, 96], + [230, 150, 140], [128, 64, 128], [ + 110, 110, 110], [244, 35, 232], + [150, 100, 100], [70, 70, 70], [150, 120, 90], [220, 20, 60], + [255, 0, 0], [255, 0, 100], [255, 0, 200], [200, 128, 128], + [255, 255, 255], [64, 170, 64], [230, 160, 50], [70, 130, 180], + [190, 255, 255], [152, 251, 152], [107, 142, 35], [0, 170, 30], + [255, 255, 128], [250, 0, 30], [100, 140, 180], [220, 220, 220], + [220, 128, 128], [222, 40, 40], [100, 170, 30], [40, 40, 40], + [33, 33, 33], [100, 128, 160], [142, 0, 0], [70, 100, 150], + [210, 170, 100], [153, 153, 153], [128, 128, 128], [0, 0, 80], + [250, 170, 30], [192, 192, 192], [220, 220, 0], [140, 140, 20], + [119, 11, 32], [150, 0, 255], [ + 0, 60, 100], [0, 0, 142], [0, 0, 90], + [0, 0, 230], [0, 80, 100], [128, 64, 64], [0, 0, 110], [0, 0, 70], + [0, 0, 192], [32, 32, 32], [120, 10, 10], [0, 0, 0]] + + def __init__(self, 
**kwargs): + super(MapillaryDataset, self).__init__( + img_suffix='.jpg', + seg_map_suffix='.png', + reduce_zero_label=False, + **kwargs) \ No newline at end of file diff --git a/models/mmseg/datasets/pascal_context.py b/models/mmseg/datasets/pascal_context.py new file mode 100644 index 0000000000000000000000000000000000000000..ab42877f1e0c60099303a05021ea288f9c1c6e15 --- /dev/null +++ b/models/mmseg/datasets/pascal_context.py @@ -0,0 +1,54 @@ +import os.path as osp + +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class PascalContextDataset(CustomDataset): + """PascalContext dataset. + + In segmentation map annotation for PascalContext, 0 stands for background, + which is included in 60 categories. ``reduce_zero_label`` is fixed to + False. The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is + fixed to '.png'. + + Args: + split (str): Split txt file for PascalContext. + """ + + CLASSES = ('background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', + 'bus', 'car', 'cat', 'chair', 'cow', 'table', 'dog', 'horse', + 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', + 'tvmonitor', 'bag', 'bed', 'bench', 'book', 'building', + 'cabinet', 'ceiling', 'cloth', 'computer', 'cup', 'door', + 'fence', 'floor', 'flower', 'food', 'grass', 'ground', + 'keyboard', 'light', 'mountain', 'mouse', 'curtain', 'platform', + 'sign', 'plate', 'road', 'rock', 'shelves', 'sidewalk', 'sky', + 'snow', 'bedclothes', 'track', 'tree', 'truck', 'wall', 'water', + 'window', 'wood') + + PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], + [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], + [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], + [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], + [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], + [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], + [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 
220], + [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], + [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], + [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], + [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], + [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], + [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], + [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], + [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]] + + def __init__(self, split, **kwargs): + super(PascalContextDataset, self).__init__( + img_suffix='.jpg', + seg_map_suffix='.png', + split=split, + reduce_zero_label=False, + **kwargs) + assert osp.exists(self.img_dir) and self.split is not None diff --git a/models/mmseg/datasets/pipelines/__init__.py b/models/mmseg/datasets/pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3bc50a88df38f9afc6d157f0077e37baf043dc11 --- /dev/null +++ b/models/mmseg/datasets/pipelines/__init__.py @@ -0,0 +1,16 @@ +from .compose import Compose +from .formating import (Collect, ImageToTensor, ToDataContainer, ToTensor, + Transpose, to_tensor) +from .loading import LoadAnnotations, LoadImageFromFile +from .test_time_aug import MultiScaleFlipAug +from .transforms import (AlignedResize, CLAHE, AdjustGamma, Normalize, Pad, + PhotoMetricDistortion, RandomCrop, RandomFlip, + RandomRotate, Rerange, Resize, RGB2Gray, SegRescale) + +__all__ = [ + 'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer', + 'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile', + 'MultiScaleFlipAug', 'AlignedResize', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop', + 'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate', + 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray' +] diff --git a/models/mmseg/datasets/pipelines/compose.py b/models/mmseg/datasets/pipelines/compose.py new file mode 100644 index 
0000000000000000000000000000000000000000..ca48f1c935755c486edc2744e1713e2b5ba3cdc8 --- /dev/null +++ b/models/mmseg/datasets/pipelines/compose.py @@ -0,0 +1,51 @@ +import collections + +from mmcv.utils import build_from_cfg + +from ..builder import PIPELINES + + +@PIPELINES.register_module() +class Compose(object): + """Compose multiple transforms sequentially. + + Args: + transforms (Sequence[dict | callable]): Sequence of transform object or + config dict to be composed. + """ + + def __init__(self, transforms): + assert isinstance(transforms, collections.abc.Sequence) + self.transforms = [] + for transform in transforms: + if isinstance(transform, dict): + transform = build_from_cfg(transform, PIPELINES) + self.transforms.append(transform) + elif callable(transform): + self.transforms.append(transform) + else: + raise TypeError('transform must be callable or a dict') + + def __call__(self, data): + """Call function to apply transforms sequentially. + + Args: + data (dict): A result dict contains the data to transform. + + Returns: + dict: Transformed data. + """ + + for t in self.transforms: + data = t(data) + if data is None: + return None + return data + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + for t in self.transforms: + format_string += '\n' + format_string += f' {t}' + format_string += '\n)' + return format_string diff --git a/models/mmseg/datasets/pipelines/formating.py b/models/mmseg/datasets/pipelines/formating.py new file mode 100644 index 0000000000000000000000000000000000000000..34061c1dd160d4b00aac8dbdc82dccf5c3883ce8 --- /dev/null +++ b/models/mmseg/datasets/pipelines/formating.py @@ -0,0 +1,288 @@ +from collections.abc import Sequence + +import mmcv +import numpy as np +import torch +from mmcv.parallel import DataContainer as DC + +from ..builder import PIPELINES + + +def to_tensor(data): + """Convert objects of various python types to :obj:`torch.Tensor`. 
+ + Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, + :class:`Sequence`, :class:`int` and :class:`float`. + + Args: + data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to + be converted. + """ + + if isinstance(data, torch.Tensor): + return data + elif isinstance(data, np.ndarray): + return torch.from_numpy(data) + elif isinstance(data, Sequence) and not mmcv.is_str(data): + return torch.tensor(data) + elif isinstance(data, int): + return torch.LongTensor([data]) + elif isinstance(data, float): + return torch.FloatTensor([data]) + else: + raise TypeError(f'type {type(data)} cannot be converted to tensor.') + + +@PIPELINES.register_module() +class ToTensor(object): + """Convert some results to :obj:`torch.Tensor` by given keys. + + Args: + keys (Sequence[str]): Keys that need to be converted to Tensor. + """ + + def __init__(self, keys): + self.keys = keys + + def __call__(self, results): + """Call function to convert data in results to :obj:`torch.Tensor`. + + Args: + results (dict): Result dict contains the data to convert. + + Returns: + dict: The result dict contains the data converted + to :obj:`torch.Tensor`. + """ + + for key in self.keys: + results[key] = to_tensor(results[key]) + return results + + def __repr__(self): + return self.__class__.__name__ + f'(keys={self.keys})' + + +@PIPELINES.register_module() +class ImageToTensor(object): + """Convert image to :obj:`torch.Tensor` by given keys. + + The dimension order of input image is (H, W, C). The pipeline will convert + it to (C, H, W). If only 2 dimension (H, W) is given, the output would be + (1, H, W). + + Args: + keys (Sequence[str]): Key of images to be converted to Tensor. + """ + + def __init__(self, keys): + self.keys = keys + + def __call__(self, results): + """Call function to convert image in results to :obj:`torch.Tensor` and + transpose the channel order. + + Args: + results (dict): Result dict contains the image data to convert. 
+ + Returns: + dict: The result dict contains the image converted + to :obj:`torch.Tensor` and transposed to (C, H, W) order. + """ + + for key in self.keys: + img = results[key] + if len(img.shape) < 3: + img = np.expand_dims(img, -1) + results[key] = to_tensor(img.transpose(2, 0, 1)) + return results + + def __repr__(self): + return self.__class__.__name__ + f'(keys={self.keys})' + + +@PIPELINES.register_module() +class Transpose(object): + """Transpose some results by given keys. + + Args: + keys (Sequence[str]): Keys of results to be transposed. + order (Sequence[int]): Order of transpose. + """ + + def __init__(self, keys, order): + self.keys = keys + self.order = order + + def __call__(self, results): + """Call function to convert image in results to :obj:`torch.Tensor` and + transpose the channel order. + + Args: + results (dict): Result dict contains the image data to convert. + + Returns: + dict: The result dict contains the image converted + to :obj:`torch.Tensor` and transposed to (C, H, W) order. + """ + + for key in self.keys: + results[key] = results[key].transpose(self.order) + return results + + def __repr__(self): + return self.__class__.__name__ + \ + f'(keys={self.keys}, order={self.order})' + + +@PIPELINES.register_module() +class ToDataContainer(object): + """Convert results to :obj:`mmcv.DataContainer` by given fields. + + Args: + fields (Sequence[dict]): Each field is a dict like + ``dict(key='xxx', **kwargs)``. The ``key`` in result will + be converted to :obj:`mmcv.DataContainer` with ``**kwargs``. + Default: ``(dict(key='img', stack=True), + dict(key='gt_semantic_seg'))``. + """ + + def __init__(self, + fields=(dict(key='img', + stack=True), dict(key='gt_semantic_seg'))): + self.fields = fields + + def __call__(self, results): + """Call function to convert data in results to + :obj:`mmcv.DataContainer`. + + Args: + results (dict): Result dict contains the data to convert. 
+ + Returns: + dict: The result dict contains the data converted to + :obj:`mmcv.DataContainer`. + """ + + for field in self.fields: + field = field.copy() + key = field.pop('key') + results[key] = DC(results[key], **field) + return results + + def __repr__(self): + return self.__class__.__name__ + f'(fields={self.fields})' + + +@PIPELINES.register_module() +class DefaultFormatBundle(object): + """Default formatting bundle. + + It simplifies the pipeline of formatting common fields, including "img" + and "gt_semantic_seg". These fields are formatted as follows. + + - img: (1)transpose, (2)to tensor, (3)to DataContainer (stack=True) + - gt_semantic_seg: (1)unsqueeze dim-0 (2)to tensor, + (3)to DataContainer (stack=True) + """ + + def __call__(self, results): + """Call function to transform and format common fields in results. + + Args: + results (dict): Result dict contains the data to convert. + + Returns: + dict: The result dict contains the data that is formatted with + default bundle. + """ + + if 'img' in results: + img = results['img'] + if len(img.shape) < 3: + img = np.expand_dims(img, -1) + img = np.ascontiguousarray(img.transpose(2, 0, 1)) + results['img'] = DC(to_tensor(img), stack=True) + if 'gt_semantic_seg' in results: + # convert to long + results['gt_semantic_seg'] = DC( + to_tensor(results['gt_semantic_seg'][None, + ...].astype(np.int64)), + stack=True) + return results + + def __repr__(self): + return self.__class__.__name__ + + +@PIPELINES.register_module() +class Collect(object): + """Collect data from the loader relevant to the specific task. + + This is usually the last stage of the data loader pipeline. Typically keys + is set to some subset of "img", "gt_semantic_seg". + + The "img_meta" item is always populated. The contents of the "img_meta" + dictionary depends on "meta_keys". By default this includes: + + - "img_shape": shape of the image input to the network as a tuple + (h, w, c). 
Note that images may be zero padded on the bottom/right + if the batch tensor is larger than this shape. + + - "scale_factor": a float indicating the preprocessing scale + + - "flip": a boolean indicating if image flip transform was used + + - "filename": path to the image file + + - "ori_shape": original shape of the image as a tuple (h, w, c) + + - "pad_shape": image shape after padding + + - "img_norm_cfg": a dict of normalization information: + - mean - per channel mean subtraction + - std - per channel std divisor + - to_rgb - bool indicating if bgr was converted to rgb + + Args: + keys (Sequence[str]): Keys of results to be collected in ``data``. + meta_keys (Sequence[str], optional): Meta keys to be converted to + ``mmcv.DataContainer`` and collected in ``data[img_metas]``. + Default: ``('filename', 'ori_filename', 'ori_shape', 'img_shape', + 'pad_shape', 'scale_factor', 'flip', 'flip_direction', + 'img_norm_cfg')`` + """ + + def __init__(self, + keys, + meta_keys=('filename', 'ori_filename', 'ori_shape', + 'img_shape', 'pad_shape', 'scale_factor', 'flip', + 'flip_direction', 'img_norm_cfg')): + self.keys = keys + self.meta_keys = meta_keys + + def __call__(self, results): + """Call function to collect keys in results. The keys in ``meta_keys`` + will be converted to :obj:mmcv.DataContainer. + + Args: + results (dict): Result dict contains the data to collect. 
+ + Returns: + dict: The result dict contains the following keys + - keys in``self.keys`` + - ``img_metas`` + """ + + data = {} + img_meta = {} + for key in self.meta_keys: + img_meta[key] = results[key] + data['img_metas'] = DC(img_meta, cpu_only=True) + for key in self.keys: + data[key] = results[key] + return data + + def __repr__(self): + return self.__class__.__name__ + \ + f'(keys={self.keys}, meta_keys={self.meta_keys})' diff --git a/models/mmseg/datasets/pipelines/loading.py b/models/mmseg/datasets/pipelines/loading.py new file mode 100644 index 0000000000000000000000000000000000000000..fdfc496ba96828a435febbef958fdae499d034f7 --- /dev/null +++ b/models/mmseg/datasets/pipelines/loading.py @@ -0,0 +1,153 @@ +import os.path as osp + +import mmcv +import numpy as np + +from ..builder import PIPELINES + + +@PIPELINES.register_module() +class LoadImageFromFile(object): + """Load an image from file. + + Required keys are "img_prefix" and "img_info" (a dict that must contain the + key "filename"). Added or updated keys are "filename", "img", "img_shape", + "ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`), + "scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1). + + Args: + to_float32 (bool): Whether to convert the loaded image to a float32 + numpy array. If set to False, the loaded image is an uint8 array. + Defaults to False. + color_type (str): The flag argument for :func:`mmcv.imfrombytes`. + Defaults to 'color'. + file_client_args (dict): Arguments to instantiate a FileClient. + See :class:`mmcv.fileio.FileClient` for details. + Defaults to ``dict(backend='disk')``. + imdecode_backend (str): Backend for :func:`mmcv.imdecode`. 
Default: + 'cv2' + """ + + def __init__(self, + to_float32=False, + color_type='color', + file_client_args=dict(backend='disk'), + imdecode_backend='cv2'): + self.to_float32 = to_float32 + self.color_type = color_type + self.file_client_args = file_client_args.copy() + self.file_client = None + self.imdecode_backend = imdecode_backend + + def __call__(self, results): + """Call functions to load image and get image meta information. + + Args: + results (dict): Result dict from :obj:`mmseg.CustomDataset`. + + Returns: + dict: The dict contains loaded image and meta information. + """ + + if self.file_client is None: + self.file_client = mmcv.FileClient(**self.file_client_args) + + if results.get('img_prefix') is not None: + filename = osp.join(results['img_prefix'], + results['img_info']['filename']) + else: + filename = results['img_info']['filename'] + img_bytes = self.file_client.get(filename) + img = mmcv.imfrombytes( + img_bytes, flag=self.color_type, backend=self.imdecode_backend) + if self.to_float32: + img = img.astype(np.float32) + + results['filename'] = filename + results['ori_filename'] = results['img_info']['filename'] + results['img'] = img + results['img_shape'] = img.shape + results['ori_shape'] = img.shape + # Set initial values for default meta_keys + results['pad_shape'] = img.shape + results['scale_factor'] = 1.0 + num_channels = 1 if len(img.shape) < 3 else img.shape[2] + results['img_norm_cfg'] = dict( + mean=np.zeros(num_channels, dtype=np.float32), + std=np.ones(num_channels, dtype=np.float32), + to_rgb=False) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(to_float32={self.to_float32},' + repr_str += f"color_type='{self.color_type}'," + repr_str += f"imdecode_backend='{self.imdecode_backend}')" + return repr_str + + +@PIPELINES.register_module() +class LoadAnnotations(object): + """Load annotations for semantic segmentation. 
+ + Args: + reduce_zero_label (bool): Whether reduce all label value by 1. + Usually used for datasets where 0 is background label. + Default: False. + file_client_args (dict): Arguments to instantiate a FileClient. + See :class:`mmcv.fileio.FileClient` for details. + Defaults to ``dict(backend='disk')``. + imdecode_backend (str): Backend for :func:`mmcv.imdecode`. Default: + 'pillow' + """ + + def __init__(self, + reduce_zero_label=False, + file_client_args=dict(backend='disk'), + imdecode_backend='pillow'): + self.reduce_zero_label = reduce_zero_label + self.file_client_args = file_client_args.copy() + self.file_client = None + self.imdecode_backend = imdecode_backend + + def __call__(self, results): + """Call function to load multiple types annotations. + + Args: + results (dict): Result dict from :obj:`mmseg.CustomDataset`. + + Returns: + dict: The dict contains loaded semantic segmentation annotations. + """ + + if self.file_client is None: + self.file_client = mmcv.FileClient(**self.file_client_args) + + if results.get('seg_prefix', None) is not None: + filename = osp.join(results['seg_prefix'], + results['ann_info']['seg_map']) + else: + filename = results['ann_info']['seg_map'] + img_bytes = self.file_client.get(filename) + gt_semantic_seg = mmcv.imfrombytes( + img_bytes, flag='unchanged', + backend=self.imdecode_backend).squeeze().astype(np.uint8) + # modify if custom classes + if results.get('label_map', None) is not None: + for old_id, new_id in results['label_map'].items(): + gt_semantic_seg[gt_semantic_seg == old_id] = new_id + # reduce zero_label + if self.reduce_zero_label: + # avoid using underflow conversion + gt_semantic_seg[gt_semantic_seg == 0] = 255 + gt_semantic_seg = gt_semantic_seg - 1 + gt_semantic_seg[gt_semantic_seg == 254] = 255 + results['gt_semantic_seg'] = gt_semantic_seg + results['seg_fields'].append('gt_semantic_seg') + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += 
f'(reduce_zero_label={self.reduce_zero_label},' + repr_str += f"imdecode_backend='{self.imdecode_backend}')" + return repr_str diff --git a/models/mmseg/datasets/pipelines/test_time_aug.py b/models/mmseg/datasets/pipelines/test_time_aug.py new file mode 100644 index 0000000000000000000000000000000000000000..473a12bc86b57e564c415ff8bdb1e431425370db --- /dev/null +++ b/models/mmseg/datasets/pipelines/test_time_aug.py @@ -0,0 +1,133 @@ +import warnings + +import mmcv + +from ..builder import PIPELINES +from .compose import Compose + + +@PIPELINES.register_module() +class MultiScaleFlipAug(object): + """Test-time augmentation with multiple scales and flipping. + + An example configuration is as followed: + + .. code-block:: + + img_scale=(2048, 1024), + img_ratios=[0.5, 1.0], + flip=True, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ] + + After MultiScaleFLipAug with above configuration, the results are wrapped + into lists of the same length as followed: + + .. code-block:: + + dict( + img=[...], + img_shape=[...], + scale=[(1024, 512), (1024, 512), (2048, 1024), (2048, 1024)] + flip=[False, True, False, True] + ... + ) + + Args: + transforms (list[dict]): Transforms to apply in each augmentation. + img_scale (None | tuple | list[tuple]): Images scales for resizing. + img_ratios (float | list[float]): Image ratios for resizing + flip (bool): Whether apply flip augmentation. Default: False. + flip_direction (str | list[str]): Flip augmentation directions, + options are "horizontal" and "vertical". If flip_direction is list, + multiple flip augmentations will be applied. + It has no effect when flip == False. Default: "horizontal". 
+ """ + + def __init__(self, + transforms, + img_scale, + img_ratios=None, + flip=False, + flip_direction='horizontal'): + self.transforms = Compose(transforms) + if img_ratios is not None: + img_ratios = img_ratios if isinstance(img_ratios, + list) else [img_ratios] + assert mmcv.is_list_of(img_ratios, float) + if img_scale is None: + # mode 1: given img_scale=None and a range of image ratio + self.img_scale = None + assert mmcv.is_list_of(img_ratios, float) + elif isinstance(img_scale, tuple) and mmcv.is_list_of( + img_ratios, float): + assert len(img_scale) == 2 + # mode 2: given a scale and a range of image ratio + self.img_scale = [(int(img_scale[0] * ratio), + int(img_scale[1] * ratio)) + for ratio in img_ratios] + else: + # mode 3: given multiple scales + self.img_scale = img_scale if isinstance(img_scale, + list) else [img_scale] + assert mmcv.is_list_of(self.img_scale, tuple) or self.img_scale is None + self.flip = flip + self.img_ratios = img_ratios + self.flip_direction = flip_direction if isinstance( + flip_direction, list) else [flip_direction] + assert mmcv.is_list_of(self.flip_direction, str) + if not self.flip and self.flip_direction != ['horizontal']: + warnings.warn( + 'flip_direction has no effect when flip is set to False') + if (self.flip + and not any([t['type'] == 'RandomFlip' for t in transforms])): + warnings.warn( + 'flip has no effect when RandomFlip is not in transforms') + + def __call__(self, results): + """Call function to apply test time augment transforms on results. + + Args: + results (dict): Result dict contains the data to transform. + + Returns: + dict[str: list]: The augmented data, where each value is wrapped + into a list. 
+ """ + + aug_data = [] + if self.img_scale is None and mmcv.is_list_of(self.img_ratios, float): + h, w = results['img'].shape[:2] + img_scale = [(int(w * ratio), int(h * ratio)) + for ratio in self.img_ratios] + else: + img_scale = self.img_scale + flip_aug = [False, True] if self.flip else [False] + for scale in img_scale: + for flip in flip_aug: + for direction in self.flip_direction: + _results = results.copy() + _results['scale'] = scale + _results['flip'] = flip + _results['flip_direction'] = direction + data = self.transforms(_results) + aug_data.append(data) + # list of dict to dict of list + aug_data_dict = {key: [] for key in aug_data[0]} + for data in aug_data: + for key, val in data.items(): + aug_data_dict[key].append(val) + return aug_data_dict + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(transforms={self.transforms}, ' + repr_str += f'img_scale={self.img_scale}, flip={self.flip})' + repr_str += f'flip_direction={self.flip_direction}' + return repr_str diff --git a/models/mmseg/datasets/pipelines/transforms.py b/models/mmseg/datasets/pipelines/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..4f41e3981dc4caa166721e9eadc687b4327c8256 --- /dev/null +++ b/models/mmseg/datasets/pipelines/transforms.py @@ -0,0 +1,1215 @@ +import mmcv +import numpy as np +from mmcv.utils import deprecated_api_warning, is_tuple_of +from numpy import random + +from ..builder import PIPELINES +from IPython import embed + +@PIPELINES.register_module() +class AlignedResize(object): + """Resize images & seg. + + This class is based on ``Resize``, the only difference is + it ensure the long and short sides are divisible by ``size_divisor``. 
+ """ + + def __init__(self, + img_scale=None, + multiscale_mode='range', + ratio_range=None, + keep_ratio=True, + size_divisor=32): + if img_scale is None: + self.img_scale = None + else: + if isinstance(img_scale, list): + self.img_scale = img_scale + else: + self.img_scale = [img_scale] + assert mmcv.is_list_of(self.img_scale, tuple) + + if ratio_range is not None: + # mode 1: given img_scale=None and a range of image ratio + # mode 2: given a scale and a range of image ratio + assert self.img_scale is None or len(self.img_scale) == 1 + else: + # mode 3 and 4: given multiple scales or a range of scales + assert multiscale_mode in ['value', 'range'] + + self.multiscale_mode = multiscale_mode + self.ratio_range = ratio_range + self.keep_ratio = keep_ratio + self.size_divisor = size_divisor + + @staticmethod + def random_select(img_scales): + """Randomly select an img_scale from given candidates. + + Args: + img_scales (list[tuple]): Images scales for selection. + + Returns: + (tuple, int): Returns a tuple ``(img_scale, scale_dix)``, + where ``img_scale`` is the selected image scale and + ``scale_idx`` is the selected index in the given candidates. + """ + + assert mmcv.is_list_of(img_scales, tuple) + scale_idx = np.random.randint(len(img_scales)) + img_scale = img_scales[scale_idx] + return img_scale, scale_idx + + @staticmethod + def random_sample(img_scales): + """Randomly sample an img_scale when ``multiscale_mode=='range'``. + + Args: + img_scales (list[tuple]): Images scale range for sampling. + There must be two tuples in img_scales, which specify the lower + and uper bound of image scales. + + Returns: + (tuple, None): Returns a tuple ``(img_scale, None)``, where + ``img_scale`` is sampled scale and None is just a placeholder + to be consistent with :func:`random_select`. 
+ """ + + assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2 + img_scale_long = [max(s) for s in img_scales] + img_scale_short = [min(s) for s in img_scales] + long_edge = np.random.randint( + min(img_scale_long), + max(img_scale_long) + 1) + short_edge = np.random.randint( + min(img_scale_short), + max(img_scale_short) + 1) + img_scale = (long_edge, short_edge) + return img_scale, None + + @staticmethod + def random_sample_ratio(img_scale, ratio_range): + """Randomly sample an img_scale when ``ratio_range`` is specified. + + A ratio will be randomly sampled from the range specified by + ``ratio_range``. Then it would be multiplied with ``img_scale`` to + generate sampled scale. + + Args: + img_scale (tuple): Images scale base to multiply with ratio. + ratio_range (tuple[float]): The minimum and maximum ratio to scale + the ``img_scale``. + + Returns: + (tuple, None): Returns a tuple ``(scale, None)``, where + ``scale`` is sampled ratio multiplied with ``img_scale`` and + None is just a placeholder to be consistent with + :func:`random_select`. + """ + + assert isinstance(img_scale, tuple) and len(img_scale) == 2 + min_ratio, max_ratio = ratio_range + assert min_ratio <= max_ratio + ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio + scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio) + return scale, None + + def _random_scale(self, results): + """Randomly sample an img_scale according to ``ratio_range`` and + ``multiscale_mode``. + + If ``ratio_range`` is specified, a ratio will be sampled and be + multiplied with ``img_scale``. + If multiple scales are specified by ``img_scale``, a scale will be + sampled according to ``multiscale_mode``. + Otherwise, single scale will be used. + + Args: + results (dict): Result dict from :obj:`dataset`. + + Returns: + dict: Two new keys 'scale` and 'scale_idx` are added into + ``results``, which would be used by subsequent pipelines. 
+ """ + + if self.ratio_range is not None: + if self.img_scale is None: + h, w = results['img'].shape[:2] + scale, scale_idx = self.random_sample_ratio((w, h), + self.ratio_range) + else: + scale, scale_idx = self.random_sample_ratio( + self.img_scale[0], self.ratio_range) + elif len(self.img_scale) == 1: + scale, scale_idx = self.img_scale[0], 0 + elif self.multiscale_mode == 'range': + scale, scale_idx = self.random_sample(self.img_scale) + elif self.multiscale_mode == 'value': + scale, scale_idx = self.random_select(self.img_scale) + else: + raise NotImplementedError + + results['scale'] = scale + results['scale_idx'] = scale_idx + + def _align(self, img, size_divisor, interpolation=None): + align_h = int(np.ceil(img.shape[0] / size_divisor)) * size_divisor + align_w = int(np.ceil(img.shape[1] / size_divisor)) * size_divisor + if interpolation == None: + img = mmcv.imresize(img, (align_w, align_h)) + else: + img = mmcv.imresize(img, (align_w, align_h), interpolation=interpolation) + return img + + def _resize_img(self, results): + """Resize images with ``results['scale']``.""" + if self.keep_ratio: + img, scale_factor = mmcv.imrescale( + results['img'], results['scale'], return_scale=True) + #### align #### + img = self._align(img, self.size_divisor) + # the w_scale and h_scale has minor difference + # a real fix should be done in the mmcv.imrescale in the future + new_h, new_w = img.shape[:2] + h, w = results['img'].shape[:2] + w_scale = new_w / w + h_scale = new_h / h + else: + img, w_scale, h_scale = mmcv.imresize( + results['img'], results['scale'], return_scale=True) + + h, w = img.shape[:2] + assert int(np.ceil(h / self.size_divisor)) * self.size_divisor == h and \ + int(np.ceil(w / self.size_divisor)) * self.size_divisor == w, \ + "img size not align. 
h:{} w:{}".format(h,w) + scale_factor = np.array([w_scale, h_scale, w_scale, h_scale], + dtype=np.float32) + results['img'] = img + results['img_shape'] = img.shape + results['pad_shape'] = img.shape # in case that there is no padding + results['scale_factor'] = scale_factor + results['keep_ratio'] = self.keep_ratio + + def _resize_seg(self, results): + """Resize semantic segmentation map with ``results['scale']``.""" + for key in results.get('seg_fields', []): + if self.keep_ratio: + gt_seg = mmcv.imrescale( + results[key], results['scale'], interpolation='nearest') + gt_seg = self._align(gt_seg, self.size_divisor, interpolation='nearest') + else: + gt_seg = mmcv.imresize( + results[key], results['scale'], interpolation='nearest') + h, w = gt_seg.shape[:2] + assert int(np.ceil(h / self.size_divisor)) * self.size_divisor == h and \ + int(np.ceil(w / self.size_divisor)) * self.size_divisor == w, \ + "gt_seg size not align. h:{} w:{}".format(h, w) + results[key] = gt_seg + + def __call__(self, results): + """Call function to resize images, bounding boxes, masks, semantic + segmentation map. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor', + 'keep_ratio' keys are added into result dict. + """ + + if 'scale' not in results: + self._random_scale(results) + self._resize_img(results) + self._resize_seg(results) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += (f'(img_scale={self.img_scale}, ' + f'multiscale_mode={self.multiscale_mode}, ' + f'ratio_range={self.ratio_range}, ' + f'keep_ratio={self.keep_ratio})') + return repr_str + + +@PIPELINES.register_module() +class Resize(object): + """Resize images & seg. + + This transform resizes the input image to some scale. If the input dict + contains the key "scale", then the scale in the input dict is used, + otherwise the specified scale in the init method is used. 
``img_scale`` can be None, a tuple (single-scale) or a list of tuples
+ + Returns: + (tuple, int): Returns a tuple ``(img_scale, scale_dix)``, + where ``img_scale`` is the selected image scale and + ``scale_idx`` is the selected index in the given candidates. + """ + + assert mmcv.is_list_of(img_scales, tuple) + scale_idx = np.random.randint(len(img_scales)) + img_scale = img_scales[scale_idx] + return img_scale, scale_idx + + @staticmethod + def random_sample(img_scales): + """Randomly sample an img_scale when ``multiscale_mode=='range'``. + + Args: + img_scales (list[tuple]): Images scale range for sampling. + There must be two tuples in img_scales, which specify the lower + and uper bound of image scales. + + Returns: + (tuple, None): Returns a tuple ``(img_scale, None)``, where + ``img_scale`` is sampled scale and None is just a placeholder + to be consistent with :func:`random_select`. + """ + + assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2 + img_scale_long = [max(s) for s in img_scales] + img_scale_short = [min(s) for s in img_scales] + long_edge = np.random.randint( + min(img_scale_long), + max(img_scale_long) + 1) + short_edge = np.random.randint( + min(img_scale_short), + max(img_scale_short) + 1) + img_scale = (long_edge, short_edge) + return img_scale, None + + @staticmethod + def random_sample_ratio(img_scale, ratio_range): + """Randomly sample an img_scale when ``ratio_range`` is specified. + + A ratio will be randomly sampled from the range specified by + ``ratio_range``. Then it would be multiplied with ``img_scale`` to + generate sampled scale. + + Args: + img_scale (tuple): Images scale base to multiply with ratio. + ratio_range (tuple[float]): The minimum and maximum ratio to scale + the ``img_scale``. + + Returns: + (tuple, None): Returns a tuple ``(scale, None)``, where + ``scale`` is sampled ratio multiplied with ``img_scale`` and + None is just a placeholder to be consistent with + :func:`random_select`. 
+ """ + + assert isinstance(img_scale, tuple) and len(img_scale) == 2 + min_ratio, max_ratio = ratio_range + assert min_ratio <= max_ratio + ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio + scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio) + return scale, None + + def _random_scale(self, results): + """Randomly sample an img_scale according to ``ratio_range`` and + ``multiscale_mode``. + + If ``ratio_range`` is specified, a ratio will be sampled and be + multiplied with ``img_scale``. + If multiple scales are specified by ``img_scale``, a scale will be + sampled according to ``multiscale_mode``. + Otherwise, single scale will be used. + + Args: + results (dict): Result dict from :obj:`dataset`. + + Returns: + dict: Two new keys 'scale` and 'scale_idx` are added into + ``results``, which would be used by subsequent pipelines. + """ + + if self.ratio_range is not None: + if self.img_scale is None: + h, w = results['img'].shape[:2] + scale, scale_idx = self.random_sample_ratio((w, h), + self.ratio_range) + else: + scale, scale_idx = self.random_sample_ratio( + self.img_scale[0], self.ratio_range) + elif len(self.img_scale) == 1: + scale, scale_idx = self.img_scale[0], 0 + elif self.multiscale_mode == 'range': + scale, scale_idx = self.random_sample(self.img_scale) + elif self.multiscale_mode == 'value': + scale, scale_idx = self.random_select(self.img_scale) + else: + raise NotImplementedError + + results['scale'] = scale + results['scale_idx'] = scale_idx + + def _resize_img(self, results): + """Resize images with ``results['scale']``.""" + if self.keep_ratio: + img, scale_factor = mmcv.imrescale( + results['img'], results['scale'], return_scale=True) + # the w_scale and h_scale has minor difference + # a real fix should be done in the mmcv.imrescale in the future + new_h, new_w = img.shape[:2] + h, w = results['img'].shape[:2] + w_scale = new_w / w + h_scale = new_h / h + else: + img, w_scale, h_scale = mmcv.imresize( + 
results['img'], results['scale'], return_scale=True) + scale_factor = np.array([w_scale, h_scale, w_scale, h_scale], + dtype=np.float32) + results['img'] = img + results['img_shape'] = img.shape + results['pad_shape'] = img.shape # in case that there is no padding + results['scale_factor'] = scale_factor + results['keep_ratio'] = self.keep_ratio + + def _resize_seg(self, results): + """Resize semantic segmentation map with ``results['scale']``.""" + for key in results.get('seg_fields', []): + if self.keep_ratio: + gt_seg = mmcv.imrescale( + results[key], results['scale'], interpolation='nearest') + else: + gt_seg = mmcv.imresize( + results[key], results['scale'], interpolation='nearest') + results[key] = gt_seg + + def __call__(self, results): + """Call function to resize images, bounding boxes, masks, semantic + segmentation map. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor', + 'keep_ratio' keys are added into result dict. + """ + + if 'scale' not in results: + self._random_scale(results) + self._resize_img(results) + self._resize_seg(results) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += (f'(img_scale={self.img_scale}, ' + f'multiscale_mode={self.multiscale_mode}, ' + f'ratio_range={self.ratio_range}, ' + f'keep_ratio={self.keep_ratio})') + return repr_str + + +@PIPELINES.register_module() +class RandomFlip(object): + """Flip the image & seg. + + If the input dict contains the key "flip", then the flag will be used, + otherwise it will be randomly decided by a ratio specified in the init + method. + + Args: + prob (float, optional): The flipping probability. Default: None. + direction(str, optional): The flipping direction. Options are + 'horizontal' and 'vertical'. Default: 'horizontal'. 
+ """ + + @deprecated_api_warning({'flip_ratio': 'prob'}, cls_name='RandomFlip') + def __init__(self, prob=None, direction='horizontal'): + self.prob = prob + self.direction = direction + if prob is not None: + assert prob >= 0 and prob <= 1 + assert direction in ['horizontal', 'vertical'] + + def __call__(self, results): + """Call function to flip bounding boxes, masks, semantic segmentation + maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Flipped results, 'flip', 'flip_direction' keys are added into + result dict. + """ + + if 'flip' not in results: + flip = True if np.random.rand() < self.prob else False + results['flip'] = flip + if 'flip_direction' not in results: + results['flip_direction'] = self.direction + if results['flip']: + # flip image + results['img'] = mmcv.imflip( + results['img'], direction=results['flip_direction']) + + # flip segs + for key in results.get('seg_fields', []): + # use copy() to make numpy stride positive + results[key] = mmcv.imflip( + results[key], direction=results['flip_direction']).copy() + return results + + def __repr__(self): + return self.__class__.__name__ + f'(prob={self.prob})' + + +@PIPELINES.register_module() +class Pad(object): + """Pad the image & mask. + + There are two padding modes: (1) pad to a fixed size and (2) pad to the + minimum size that is divisible by some number. + Added keys are "pad_shape", "pad_fixed_size", "pad_size_divisor", + + Args: + size (tuple, optional): Fixed padding size. + size_divisor (int, optional): The divisor of padded size. + pad_val (float, optional): Padding value. Default: 0. + seg_pad_val (float, optional): Padding value of segmentation map. + Default: 255. 
+ """ + + def __init__(self, + size=None, + size_divisor=None, + pad_val=0, + seg_pad_val=255): + self.size = size + self.size_divisor = size_divisor + self.pad_val = pad_val + self.seg_pad_val = seg_pad_val + # only one of size and size_divisor should be valid + assert size is not None or size_divisor is not None + assert size is None or size_divisor is None + + def _pad_img(self, results): + """Pad images according to ``self.size``.""" + if self.size is not None: + padded_img = mmcv.impad( + results['img'], shape=self.size, pad_val=self.pad_val) + elif self.size_divisor is not None: + padded_img = mmcv.impad_to_multiple( + results['img'], self.size_divisor, pad_val=self.pad_val) + results['img'] = padded_img + results['pad_shape'] = padded_img.shape + results['pad_fixed_size'] = self.size + results['pad_size_divisor'] = self.size_divisor + + def _pad_seg(self, results): + """Pad masks according to ``results['pad_shape']``.""" + for key in results.get('seg_fields', []): + results[key] = mmcv.impad( + results[key], + shape=results['pad_shape'][:2], + pad_val=self.seg_pad_val) + + def __call__(self, results): + """Call function to pad images, masks, semantic segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Updated result dict. + """ + + self._pad_img(results) + self._pad_seg(results) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(size={self.size}, size_divisor={self.size_divisor}, ' \ + f'pad_val={self.pad_val})' + return repr_str + + +@PIPELINES.register_module() +class Normalize(object): + """Normalize the image. + + Added key is "img_norm_cfg". + + Args: + mean (sequence): Mean values of 3 channels. + std (sequence): Std values of 3 channels. + to_rgb (bool): Whether to convert the image from BGR to RGB, + default is true. 
+ """ + + def __init__(self, mean, std, to_rgb=True): + self.mean = np.array(mean, dtype=np.float32) + self.std = np.array(std, dtype=np.float32) + self.to_rgb = to_rgb + + def __call__(self, results): + """Call function to normalize images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Normalized results, 'img_norm_cfg' key is added into + result dict. + """ + + results['img'] = mmcv.imnormalize(results['img'], self.mean, self.std, + self.to_rgb) + results['img_norm_cfg'] = dict( + mean=self.mean, std=self.std, to_rgb=self.to_rgb) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(mean={self.mean}, std={self.std}, to_rgb=' \ + f'{self.to_rgb})' + return repr_str + + +@PIPELINES.register_module() +class Rerange(object): + """Rerange the image pixel value. + + Args: + min_value (float or int): Minimum value of the reranged image. + Default: 0. + max_value (float or int): Maximum value of the reranged image. + Default: 255. + """ + + def __init__(self, min_value=0, max_value=255): + assert isinstance(min_value, float) or isinstance(min_value, int) + assert isinstance(max_value, float) or isinstance(max_value, int) + assert min_value < max_value + self.min_value = min_value + self.max_value = max_value + + def __call__(self, results): + """Call function to rerange images. + + Args: + results (dict): Result dict from loading pipeline. + Returns: + dict: Reranged results. 
+ """ + + img = results['img'] + img_min_value = np.min(img) + img_max_value = np.max(img) + + assert img_min_value < img_max_value + # rerange to [0, 1] + img = (img - img_min_value) / (img_max_value - img_min_value) + # rerange to [min_value, max_value] + img = img * (self.max_value - self.min_value) + self.min_value + results['img'] = img + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(min_value={self.min_value}, max_value={self.max_value})' + return repr_str + + +@PIPELINES.register_module() +class CLAHE(object): + """Use CLAHE method to process the image. + + See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J]. + Graphics Gems, 1994:474-485.` for more information. + + Args: + clip_limit (float): Threshold for contrast limiting. Default: 40.0. + tile_grid_size (tuple[int]): Size of grid for histogram equalization. + Input image will be divided into equally sized rectangular tiles. + It defines the number of tiles in row and column. Default: (8, 8). + """ + + def __init__(self, clip_limit=40.0, tile_grid_size=(8, 8)): + assert isinstance(clip_limit, (float, int)) + self.clip_limit = clip_limit + assert is_tuple_of(tile_grid_size, int) + assert len(tile_grid_size) == 2 + self.tile_grid_size = tile_grid_size + + def __call__(self, results): + """Call function to Use CLAHE method process images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Processed results. + """ + + for i in range(results['img'].shape[2]): + results['img'][:, :, i] = mmcv.clahe( + np.array(results['img'][:, :, i], dtype=np.uint8), + self.clip_limit, self.tile_grid_size) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(clip_limit={self.clip_limit}, '\ + f'tile_grid_size={self.tile_grid_size})' + return repr_str + + +@PIPELINES.register_module() +class RandomCrop(object): + """Random crop the image & seg. 
+ + Args: + crop_size (tuple): Expected size after cropping, (h, w). + cat_max_ratio (float): The maximum ratio that single category could + occupy. + """ + + def __init__(self, crop_size, cat_max_ratio=1., ignore_index=255): + assert crop_size[0] > 0 and crop_size[1] > 0 + self.crop_size = crop_size + self.cat_max_ratio = cat_max_ratio + self.ignore_index = ignore_index + + def get_crop_bbox(self, img): + """Randomly get a crop bounding box.""" + margin_h = max(img.shape[0] - self.crop_size[0], 0) + margin_w = max(img.shape[1] - self.crop_size[1], 0) + offset_h = np.random.randint(0, margin_h + 1) + offset_w = np.random.randint(0, margin_w + 1) + crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0] + crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1] + + return crop_y1, crop_y2, crop_x1, crop_x2 + + def crop(self, img, crop_bbox): + """Crop from ``img``""" + crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox + img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...] + return img + + def __call__(self, results): + """Call function to randomly crop images, semantic segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Randomly cropped results, 'img_shape' key in result dict is + updated according to crop size. 
+ """ + + img = results['img'] + crop_bbox = self.get_crop_bbox(img) + if self.cat_max_ratio < 1.: + # Repeat 10 times + for _ in range(10): + seg_temp = self.crop(results['gt_semantic_seg'], crop_bbox) + labels, cnt = np.unique(seg_temp, return_counts=True) + cnt = cnt[labels != self.ignore_index] + if len(cnt) > 1 and np.max(cnt) / np.sum( + cnt) < self.cat_max_ratio: + break + crop_bbox = self.get_crop_bbox(img) + + # crop the image + img = self.crop(img, crop_bbox) + img_shape = img.shape + results['img'] = img + results['img_shape'] = img_shape + + # crop semantic seg + for key in results.get('seg_fields', []): + results[key] = self.crop(results[key], crop_bbox) + + return results + + def __repr__(self): + return self.__class__.__name__ + f'(crop_size={self.crop_size})' + +@PIPELINES.register_module() +class CenterCrop(object): + """Center crop the image & seg. + Args: + crop_size (tuple): Expected size after cropping, (h, w). + """ + + def __init__(self, crop_size, ignore_index=255): + assert crop_size[0] > 0 and crop_size[1] > 0 + self.crop_size = crop_size + self.ignore_index = ignore_index + + def get_crop_bbox(self, img): + """Randomly get a crop bounding box.""" + margin_h = max(img.shape[0] - self.crop_size[0], 0) + margin_w = max(img.shape[1] - self.crop_size[1], 0) + offset_h = margin_h // 2#np.random.randint(0, margin_h + 1) + offset_w = margin_w // 2#np.random.randint(0, margin_w + 1) + crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0] + crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1] + + return crop_y1, crop_y2, crop_x1, crop_x2 + + def crop(self, img, crop_bbox): + """Crop from ``img``""" + crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox + img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...] + return img + + def __call__(self, results): + """Call function to randomly crop images, semantic segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. 
+ + Returns: + dict: Randomly cropped results, 'img_shape' key in result dict is + updated according to crop size. + """ + + img = results['img'] + crop_bbox = self.get_crop_bbox(img) + + # crop the image + img = self.crop(img, crop_bbox) + img_shape = img.shape + results['img'] = img + results['img_shape'] = img_shape + + # crop semantic seg + for key in results.get('seg_fields', []): + results[key] = self.crop(results[key], crop_bbox) + + return results + + def __repr__(self): + return self.__class__.__name__ + f'(crop_size={self.crop_size})' + + +@PIPELINES.register_module() +class RandomRotate(object): + """Rotate the image & seg. + + Args: + prob (float): The rotation probability. + degree (float, tuple[float]): Range of degrees to select from. If + degree is a number instead of tuple like (min, max), + the range of degree will be (``-degree``, ``+degree``) + pad_val (float, optional): Padding value of image. Default: 0. + seg_pad_val (float, optional): Padding value of segmentation map. + Default: 255. + center (tuple[float], optional): Center point (w, h) of the rotation in + the source image. If not specified, the center of the image will be + used. Default: None. + auto_bound (bool): Whether to adjust the image size to cover the whole + rotated image. Default: False + """ + + def __init__(self, + prob, + degree, + pad_val=0, + seg_pad_val=255, + center=None, + auto_bound=False): + self.prob = prob + assert prob >= 0 and prob <= 1 + if isinstance(degree, (float, int)): + assert degree > 0, f'degree {degree} should be positive' + self.degree = (-degree, degree) + else: + self.degree = degree + assert len(self.degree) == 2, f'degree {self.degree} should be a ' \ + f'tuple of (min, max)' + self.pal_val = pad_val + self.seg_pad_val = seg_pad_val + self.center = center + self.auto_bound = auto_bound + + def __call__(self, results): + """Call function to rotate image, semantic segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. 
+ + Returns: + dict: Rotated results. + """ + + rotate = True if np.random.rand() < self.prob else False + degree = np.random.uniform(min(*self.degree), max(*self.degree)) + if rotate: + # rotate image + results['img'] = mmcv.imrotate( + results['img'], + angle=degree, + border_value=self.pal_val, + center=self.center, + auto_bound=self.auto_bound) + + # rotate segs + for key in results.get('seg_fields', []): + results[key] = mmcv.imrotate( + results[key], + angle=degree, + border_value=self.seg_pad_val, + center=self.center, + auto_bound=self.auto_bound, + interpolation='nearest') + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(prob={self.prob}, ' \ + f'degree={self.degree}, ' \ + f'pad_val={self.pal_val}, ' \ + f'seg_pad_val={self.seg_pad_val}, ' \ + f'center={self.center}, ' \ + f'auto_bound={self.auto_bound})' + return repr_str + + +@PIPELINES.register_module() +class RGB2Gray(object): + """Convert RGB image to grayscale image. + + This transform calculate the weighted mean of input image channels with + ``weights`` and then expand the channels to ``out_channels``. When + ``out_channels`` is None, the number of output channels is the same as + input channels. + + Args: + out_channels (int): Expected number of output channels after + transforming. Default: None. + weights (tuple[float]): The weights to calculate the weighted mean. + Default: (0.299, 0.587, 0.114). + """ + + def __init__(self, out_channels=None, weights=(0.299, 0.587, 0.114)): + assert out_channels is None or out_channels > 0 + self.out_channels = out_channels + assert isinstance(weights, tuple) + for item in weights: + assert isinstance(item, (float, int)) + self.weights = weights + + def __call__(self, results): + """Call function to convert RGB image to grayscale image. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with grayscale image. 
+ """ + img = results['img'] + assert len(img.shape) == 3 + assert img.shape[2] == len(self.weights) + weights = np.array(self.weights).reshape((1, 1, -1)) + img = (img * weights).sum(2, keepdims=True) + if self.out_channels is None: + img = img.repeat(weights.shape[2], axis=2) + else: + img = img.repeat(self.out_channels, axis=2) + + results['img'] = img + results['img_shape'] = img.shape + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(out_channels={self.out_channels}, ' \ + f'weights={self.weights})' + return repr_str + + +@PIPELINES.register_module() +class AdjustGamma(object): + """Using gamma correction to process the image. + + Args: + gamma (float or int): Gamma value used in gamma correction. + Default: 1.0. + """ + + def __init__(self, gamma=1.0): + assert isinstance(gamma, float) or isinstance(gamma, int) + assert gamma > 0 + self.gamma = gamma + inv_gamma = 1.0 / gamma + self.table = np.array([(i / 255.0)**inv_gamma * 255 + for i in np.arange(256)]).astype('uint8') + + def __call__(self, results): + """Call function to process the image with gamma correction. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Processed results. + """ + + results['img'] = mmcv.lut_transform( + np.array(results['img'], dtype=np.uint8), self.table) + + return results + + def __repr__(self): + return self.__class__.__name__ + f'(gamma={self.gamma})' + +@PIPELINES.register_module() +class MaillaryHack(object): + """ map MV 65 class to 19 class like Cityscapes + """ + def __init__(self): + self.map = [[13, 24, 41], [2, 15], [17], [6], [3], [45, 47], [48], [50], [30], [29], + [27], [19], [20, 21, 22], [55], [61], [54], [58], [57], [52]] + + self.others = [i for i in range(66)] + for i in self.map: + for j in i: + if j in self.others: + self.others.remove(j) + + + def __call__(self, results): + """Call function to process the image with gamma correction. 
+ + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Processed results. + """ + gt_map = results['gt_semantic_seg'] + # others -> 255 + for value in self.others: + gt_map[gt_map == value] = 255 + + for index, map in enumerate(self.map): + for value in map: + gt_map[gt_map == value] = index + + results['gt_semantic_seg'] = gt_map + + return results + + def __repr__(self): + return 'MaillaryHack' + + +@PIPELINES.register_module() +class SegRescale(object): + """Rescale semantic segmentation maps. + + Args: + scale_factor (float): The scale factor of the final output. + """ + + def __init__(self, scale_factor=1): + self.scale_factor = scale_factor + + def __call__(self, results): + """Call function to scale the semantic segmentation map. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with semantic segmentation map scaled. + """ + for key in results.get('seg_fields', []): + if self.scale_factor != 1: + results[key] = mmcv.imrescale( + results[key], self.scale_factor, interpolation='nearest') + return results + + def __repr__(self): + return self.__class__.__name__ + f'(scale_factor={self.scale_factor})' + + +@PIPELINES.register_module() +class PhotoMetricDistortion(object): + """Apply photometric distortion to image sequentially, every transformation + is applied with a probability of 0.5. The position of random contrast is in + second or second to last. + + 1. random brightness + 2. random contrast (mode 0) + 3. convert color from BGR to HSV + 4. random saturation + 5. random hue + 6. convert color from HSV to BGR + 7. random contrast (mode 1) + 8. randomly swap channels + + Args: + brightness_delta (int): delta of brightness. + contrast_range (tuple): range of contrast. + saturation_range (tuple): range of saturation. + hue_delta (int): delta of hue. 
+ """ + + def __init__(self, + brightness_delta=32, + contrast_range=(0.5, 1.5), + saturation_range=(0.5, 1.5), + hue_delta=18): + self.brightness_delta = brightness_delta + self.contrast_lower, self.contrast_upper = contrast_range + self.saturation_lower, self.saturation_upper = saturation_range + self.hue_delta = hue_delta + + def convert(self, img, alpha=1, beta=0): + """Multiple with alpha and add beat with clip.""" + img = img.astype(np.float32) * alpha + beta + img = np.clip(img, 0, 255) + return img.astype(np.uint8) + + def brightness(self, img): + """Brightness distortion.""" + if random.randint(2): + return self.convert( + img, + beta=random.uniform(-self.brightness_delta, + self.brightness_delta)) + return img + + def contrast(self, img): + """Contrast distortion.""" + if random.randint(2): + return self.convert( + img, + alpha=random.uniform(self.contrast_lower, self.contrast_upper)) + return img + + def saturation(self, img): + """Saturation distortion.""" + if random.randint(2): + img = mmcv.bgr2hsv(img) + img[:, :, 1] = self.convert( + img[:, :, 1], + alpha=random.uniform(self.saturation_lower, + self.saturation_upper)) + img = mmcv.hsv2bgr(img) + return img + + def hue(self, img): + """Hue distortion.""" + if random.randint(2): + img = mmcv.bgr2hsv(img) + img[:, :, + 0] = (img[:, :, 0].astype(int) + + random.randint(-self.hue_delta, self.hue_delta)) % 180 + img = mmcv.hsv2bgr(img) + return img + + def __call__(self, results): + """Call function to perform photometric distortion on images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Result dict with images distorted. 
+ """ + + img = results['img'] + # random brightness + img = self.brightness(img) + + # mode == 0 --> do random contrast first + # mode == 1 --> do random contrast last + mode = random.randint(2) + if mode == 1: + img = self.contrast(img) + + # random saturation + img = self.saturation(img) + + # random hue + img = self.hue(img) + + # random contrast + if mode == 0: + img = self.contrast(img) + + results['img'] = img + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += (f'(brightness_delta={self.brightness_delta}, ' + f'contrast_range=({self.contrast_lower}, ' + f'{self.contrast_upper}), ' + f'saturation_range=({self.saturation_lower}, ' + f'{self.saturation_upper}), ' + f'hue_delta={self.hue_delta})') + return repr_str diff --git a/models/mmseg/datasets/stare.py b/models/mmseg/datasets/stare.py new file mode 100644 index 0000000000000000000000000000000000000000..cbd14e0920e7f6a73baff1432e5a32ccfdb0dfae --- /dev/null +++ b/models/mmseg/datasets/stare.py @@ -0,0 +1,27 @@ +import os.path as osp + +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class STAREDataset(CustomDataset): + """STARE dataset. + + In segmentation map annotation for STARE, 0 stands for background, which is + included in 2 categories. ``reduce_zero_label`` is fixed to False. The + ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to + '.ah.png'. 
+ """ + + CLASSES = ('background', 'vessel') + + PALETTE = [[120, 120, 120], [6, 230, 230]] + + def __init__(self, **kwargs): + super(STAREDataset, self).__init__( + img_suffix='.png', + seg_map_suffix='.ah.png', + reduce_zero_label=False, + **kwargs) + assert osp.exists(self.img_dir) diff --git a/models/mmseg/datasets/voc.py b/models/mmseg/datasets/voc.py new file mode 100644 index 0000000000000000000000000000000000000000..a8855203b14ee0dc4da9099a2945d4aedcffbcd6 --- /dev/null +++ b/models/mmseg/datasets/voc.py @@ -0,0 +1,29 @@ +import os.path as osp + +from .builder import DATASETS +from .custom import CustomDataset + + +@DATASETS.register_module() +class PascalVOCDataset(CustomDataset): + """Pascal VOC dataset. + + Args: + split (str): Split txt file for Pascal VOC. + """ + + CLASSES = ('background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', + 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', + 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', + 'train', 'tvmonitor') + + PALETTE = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128], + [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], + [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128], + [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0], + [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]] + + def __init__(self, split, **kwargs): + super(PascalVOCDataset, self).__init__( + img_suffix='.jpg', seg_map_suffix='.png', split=split, **kwargs) + assert osp.exists(self.img_dir) and self.split is not None diff --git a/models/mmseg/models/__init__.py b/models/mmseg/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3f407c08690e3cbe6db79f3d691a9888a6f94585 --- /dev/null +++ b/models/mmseg/models/__init__.py @@ -0,0 +1,7 @@ +from .builder import build_loss +from .losses import * # noqa: F401,F403 + + +__all__ = [ + 'build_loss' +] diff --git a/models/mmseg/models/__pycache__/__init__.cpython-37.pyc 
b/models/mmseg/models/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84b170ff64fc31e305ede45f67808f970dd51ef3 Binary files /dev/null and b/models/mmseg/models/__pycache__/__init__.cpython-37.pyc differ diff --git a/models/mmseg/models/__pycache__/builder.cpython-37.pyc b/models/mmseg/models/__pycache__/builder.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c36de3714b80d5cd5d52a10353bcaed1614344e Binary files /dev/null and b/models/mmseg/models/__pycache__/builder.cpython-37.pyc differ diff --git a/models/mmseg/models/builder.py b/models/mmseg/models/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..8940888c3a4dc4c3494b9fd328e14f176f53baf3 --- /dev/null +++ b/models/mmseg/models/builder.py @@ -0,0 +1,68 @@ +import warnings + +from mmcv.utils import Registry, build_from_cfg +from torch import nn +BACKBONES = Registry('backbone') +NECKS = Registry('neck') +HEADS = Registry('head') +LOSSES = Registry('loss') +SEGMENTORS = Registry('segmentor') + +# from mmseg.models.builder import BACKBONES +# bbb = BACKBONES.get('mit_b0') +# print(bbb) + +def build(cfg, registry, default_args=None): + """Build a module. + + Args: + cfg (dict, list[dict]): The config of modules, is is either a dict + or a list of configs. + registry (:obj:`Registry`): A registry the module belongs to. + default_args (dict, optional): Default arguments to build the module. + Defaults to None. + + Returns: + nn.Module: A built nn module. 
+ """ + + if isinstance(cfg, list): + modules = [ + build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg + ] + return nn.Sequential(*modules) + else: + return build_from_cfg(cfg, registry, default_args) + + +def build_backbone(cfg): + """Build backbone.""" + return build(cfg, BACKBONES) + + +def build_neck(cfg): + """Build neck.""" + return build(cfg, NECKS) + + +def build_head(cfg): + """Build head.""" + return build(cfg, HEADS) + + +def build_loss(cfg): + """Build loss.""" + return build(cfg, LOSSES) + + +def build_segmentor(cfg, train_cfg=None, test_cfg=None): + """Build segmentor.""" + if train_cfg is not None or test_cfg is not None: + warnings.warn( + 'train_cfg and test_cfg is deprecated, ' + 'please specify them in model', UserWarning) + assert cfg.get('train_cfg') is None or train_cfg is None, \ + 'train_cfg specified in both outer field and model field ' + assert cfg.get('test_cfg') is None or test_cfg is None, \ + 'test_cfg specified in both outer field and model field ' + return build(cfg, SEGMENTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg)) diff --git a/models/mmseg/models/losses/__init__.py b/models/mmseg/models/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d623887760654c5c07fbdb2c76012baa2f9a4b52 --- /dev/null +++ b/models/mmseg/models/losses/__init__.py @@ -0,0 +1,11 @@ +from .accuracy import Accuracy, accuracy +from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy, + cross_entropy, mask_cross_entropy) +from .lovasz_loss import LovaszLoss +from .utils import reduce_loss, weight_reduce_loss, weighted_loss + +__all__ = [ + 'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy', + 'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss', + 'weight_reduce_loss', 'weighted_loss', 'LovaszLoss' +] diff --git a/models/mmseg/models/losses/__pycache__/__init__.cpython-37.pyc b/models/mmseg/models/losses/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..0b16b2e4691c5033ebe59fa37ddcca20728e3931 Binary files /dev/null and b/models/mmseg/models/losses/__pycache__/__init__.cpython-37.pyc differ diff --git a/models/mmseg/models/losses/__pycache__/accuracy.cpython-37.pyc b/models/mmseg/models/losses/__pycache__/accuracy.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac4b1203ada06f5b239253c74d3b1071da1a467f Binary files /dev/null and b/models/mmseg/models/losses/__pycache__/accuracy.cpython-37.pyc differ diff --git a/models/mmseg/models/losses/__pycache__/cross_entropy_loss.cpython-37.pyc b/models/mmseg/models/losses/__pycache__/cross_entropy_loss.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d5011be2991d0ef829b74507cafdd9605fda033 Binary files /dev/null and b/models/mmseg/models/losses/__pycache__/cross_entropy_loss.cpython-37.pyc differ diff --git a/models/mmseg/models/losses/__pycache__/lovasz_loss.cpython-37.pyc b/models/mmseg/models/losses/__pycache__/lovasz_loss.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f03f96bd008bf973a97af612055678180eba2c7 Binary files /dev/null and b/models/mmseg/models/losses/__pycache__/lovasz_loss.cpython-37.pyc differ diff --git a/models/mmseg/models/losses/__pycache__/utils.cpython-37.pyc b/models/mmseg/models/losses/__pycache__/utils.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da51576225831bb7036c7fe7e5588fe0b80a526a Binary files /dev/null and b/models/mmseg/models/losses/__pycache__/utils.cpython-37.pyc differ diff --git a/models/mmseg/models/losses/accuracy.py b/models/mmseg/models/losses/accuracy.py new file mode 100644 index 0000000000000000000000000000000000000000..e45f9ec485737ef1f6717eaf3b6ddc572a169932 --- /dev/null +++ b/models/mmseg/models/losses/accuracy.py @@ -0,0 +1,78 @@ +import torch.nn as nn + + +def accuracy(pred, target, topk=1, thresh=None): + """Calculate accuracy 
according to the prediction and target. + + Args: + pred (torch.Tensor): The model prediction, shape (N, num_class, ...) + target (torch.Tensor): The target of each prediction, shape (N, , ...) + topk (int | tuple[int], optional): If the predictions in ``topk`` + matches the target, the predictions will be regarded as + correct ones. Defaults to 1. + thresh (float, optional): If not None, predictions with scores under + this threshold are considered incorrect. Default to None. + + Returns: + float | tuple[float]: If the input ``topk`` is a single integer, + the function will return a single float as accuracy. If + ``topk`` is a tuple containing multiple integers, the + function will return a tuple containing accuracies of + each ``topk`` number. + """ + assert isinstance(topk, (int, tuple)) + if isinstance(topk, int): + topk = (topk, ) + return_single = True + else: + return_single = False + + maxk = max(topk) + if pred.size(0) == 0: + accu = [pred.new_tensor(0.) for i in range(len(topk))] + return accu[0] if return_single else accu + assert pred.ndim == target.ndim + 1 + assert pred.size(0) == target.size(0) + assert maxk <= pred.size(1), \ + f'maxk {maxk} exceeds pred dimension {pred.size(1)}' + pred_value, pred_label = pred.topk(maxk, dim=1) + # transpose to shape (maxk, N, ...) + pred_label = pred_label.transpose(0, 1) + correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label)) + if thresh is not None: + # Only prediction values larger than thresh are counted as correct + correct = correct & (pred_value > thresh).t() + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / target.numel())) + return res[0] if return_single else res + + +class Accuracy(nn.Module): + """Accuracy calculation module.""" + + def __init__(self, topk=(1, ), thresh=None): + """Module to calculate the accuracy. + + Args: + topk (tuple, optional): The criterion used to calculate the + accuracy. 
Defaults to (1,). + thresh (float, optional): If not None, predictions with scores + under this threshold are considered incorrect. Default to None. + """ + super().__init__() + self.topk = topk + self.thresh = thresh + + def forward(self, pred, target): + """Forward function to calculate accuracy. + + Args: + pred (torch.Tensor): Prediction of models. + target (torch.Tensor): Target for each prediction. + + Returns: + tuple[float]: The accuracies under different topk criterions. + """ + return accuracy(pred, target, self.topk, self.thresh) diff --git a/models/mmseg/models/losses/cross_entropy_loss.py b/models/mmseg/models/losses/cross_entropy_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..44798421aaced24d0524dbd3618645fd7ebb1e86 --- /dev/null +++ b/models/mmseg/models/losses/cross_entropy_loss.py @@ -0,0 +1,198 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..builder import LOSSES +from .utils import weight_reduce_loss + + +def cross_entropy(pred, + label, + weight=None, + class_weight=None, + reduction='mean', + avg_factor=None, + ignore_index=-100): + """The wrapper function for :func:`F.cross_entropy`""" + # class_weight is a manual rescaling weight given to each class. 
+ # If given, has to be a Tensor of size C element-wise losses + loss = F.cross_entropy( + pred, + label, + weight=class_weight, + reduction='none', + ignore_index=ignore_index) + + # apply weights and do the reduction + if weight is not None: + weight = weight.float() + loss = weight_reduce_loss( + loss, weight=weight, reduction=reduction, avg_factor=avg_factor) + + return loss + + +def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index): + """Expand onehot labels to match the size of prediction.""" + bin_labels = labels.new_zeros(target_shape) + valid_mask = (labels >= 0) & (labels != ignore_index) + inds = torch.nonzero(valid_mask, as_tuple=True) + + if inds[0].numel() > 0: + if labels.dim() == 3: + bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1 + else: + bin_labels[inds[0], labels[valid_mask]] = 1 + + valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float() + if label_weights is None: + bin_label_weights = valid_mask + else: + bin_label_weights = label_weights.unsqueeze(1).expand(target_shape) + bin_label_weights *= valid_mask + + return bin_labels, bin_label_weights + + +def binary_cross_entropy(pred, + label, + weight=None, + reduction='mean', + avg_factor=None, + class_weight=None, + ignore_index=255): + """Calculate the binary CrossEntropy loss. + + Args: + pred (torch.Tensor): The prediction with shape (N, 1). + label (torch.Tensor): The learning label of the prediction. + weight (torch.Tensor, optional): Sample-wise loss weight. + reduction (str, optional): The method used to reduce the loss. + Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + class_weight (list[float], optional): The weight for each class. + ignore_index (int | None): The label index to be ignored. 
Default: 255 + + Returns: + torch.Tensor: The calculated loss + """ + if pred.dim() != label.dim(): + assert (pred.dim() == 2 and label.dim() == 1) or ( + pred.dim() == 4 and label.dim() == 3), \ + 'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \ + 'H, W], label shape [N, H, W] are supported' + label, weight = _expand_onehot_labels(label, weight, pred.shape, + ignore_index) + + # weighted element-wise losses + if weight is not None: + weight = weight.float() + loss = F.binary_cross_entropy_with_logits( + pred, label.float(), pos_weight=class_weight, reduction='none') + # do the reduction for the weighted loss + loss = weight_reduce_loss( + loss, weight, reduction=reduction, avg_factor=avg_factor) + + return loss + + +def mask_cross_entropy(pred, + target, + label, + reduction='mean', + avg_factor=None, + class_weight=None, + ignore_index=None): + """Calculate the CrossEntropy loss for masks. + + Args: + pred (torch.Tensor): The prediction with shape (N, C), C is the number + of classes. + target (torch.Tensor): The learning label of the prediction. + label (torch.Tensor): ``label`` indicates the class label of the mask' + corresponding object. This will be used to select the mask in the + of the class which the object belongs to when the mask prediction + if not class-agnostic. + reduction (str, optional): The method used to reduce the loss. + Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + class_weight (list[float], optional): The weight for each class. + ignore_index (None): Placeholder, to be consistent with other loss. + Default: None. 
+ + Returns: + torch.Tensor: The calculated loss + """ + assert ignore_index is None, 'BCE loss does not support ignore_index' + # TODO: handle these two reserved arguments + assert reduction == 'mean' and avg_factor is None + num_rois = pred.size()[0] + inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device) + pred_slice = pred[inds, label].squeeze(1) + return F.binary_cross_entropy_with_logits( + pred_slice, target, weight=class_weight, reduction='mean')[None] + + +@LOSSES.register_module() +class CrossEntropyLoss(nn.Module): + """CrossEntropyLoss. + + Args: + use_sigmoid (bool, optional): Whether the prediction uses sigmoid + of softmax. Defaults to False. + use_mask (bool, optional): Whether to use mask cross entropy loss. + Defaults to False. + reduction (str, optional): . Defaults to 'mean'. + Options are "none", "mean" and "sum". + class_weight (list[float], optional): Weight of each class. + Defaults to None. + loss_weight (float, optional): Weight of the loss. Defaults to 1.0. 
+ """ + + def __init__(self, + use_sigmoid=False, + use_mask=False, + reduction='mean', + class_weight=None, + loss_weight=1.0): + super(CrossEntropyLoss, self).__init__() + assert (use_sigmoid is False) or (use_mask is False) + self.use_sigmoid = use_sigmoid + self.use_mask = use_mask + self.reduction = reduction + self.loss_weight = loss_weight + self.class_weight = class_weight + + if self.use_sigmoid: + self.cls_criterion = binary_cross_entropy + elif self.use_mask: + self.cls_criterion = mask_cross_entropy + else: + self.cls_criterion = cross_entropy + + def forward(self, + cls_score, + label, + weight=None, + avg_factor=None, + reduction_override=None, + **kwargs): + """Forward function.""" + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if self.class_weight is not None: + class_weight = cls_score.new_tensor(self.class_weight) + else: + class_weight = None + loss_cls = self.loss_weight * self.cls_criterion( + cls_score, + label, + weight, + class_weight=class_weight, + reduction=reduction, + avg_factor=avg_factor, + **kwargs) + return loss_cls diff --git a/models/mmseg/models/losses/lovasz_loss.py b/models/mmseg/models/losses/lovasz_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..e6e2450cfdc713ee1ea166886469da4541568329 --- /dev/null +++ b/models/mmseg/models/losses/lovasz_loss.py @@ -0,0 +1,303 @@ +"""Modified from https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytor +ch/lovasz_losses.py Lovasz-Softmax and Jaccard hinge loss in PyTorch Maxim +Berman 2018 ESAT-PSI KU Leuven (MIT License)""" + +import mmcv +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..builder import LOSSES +from .utils import weight_reduce_loss + + +def lovasz_grad(gt_sorted): + """Computes gradient of the Lovasz extension w.r.t sorted errors. + + See Alg. 1 in paper. 
+ """ + p = len(gt_sorted) + gts = gt_sorted.sum() + intersection = gts - gt_sorted.float().cumsum(0) + union = gts + (1 - gt_sorted).float().cumsum(0) + jaccard = 1. - intersection / union + if p > 1: # cover 1-pixel case + jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] + return jaccard + + +def flatten_binary_logits(logits, labels, ignore_index=None): + """Flattens predictions in the batch (binary case) Remove labels equal to + 'ignore_index'.""" + logits = logits.view(-1) + labels = labels.view(-1) + if ignore_index is None: + return logits, labels + valid = (labels != ignore_index) + vlogits = logits[valid] + vlabels = labels[valid] + return vlogits, vlabels + + +def flatten_probs(probs, labels, ignore_index=None): + """Flattens predictions in the batch.""" + if probs.dim() == 3: + # assumes output of a sigmoid layer + B, H, W = probs.size() + probs = probs.view(B, 1, H, W) + B, C, H, W = probs.size() + probs = probs.permute(0, 2, 3, 1).contiguous().view(-1, C) # B*H*W, C=P,C + labels = labels.view(-1) + if ignore_index is None: + return probs, labels + valid = (labels != ignore_index) + vprobs = probs[valid.nonzero().squeeze()] + vlabels = labels[valid] + return vprobs, vlabels + + +def lovasz_hinge_flat(logits, labels): + """Binary Lovasz hinge loss. + + Args: + logits (torch.Tensor): [P], logits at each prediction + (between -infty and +infty). + labels (torch.Tensor): [P], binary ground truth labels (0 or 1). + + Returns: + torch.Tensor: The calculated loss. + """ + if len(labels) == 0: + # only void pixels, the gradients should be 0 + return logits.sum() * 0. + signs = 2. * labels.float() - 1. + errors = (1. 
- logits * signs) + errors_sorted, perm = torch.sort(errors, dim=0, descending=True) + perm = perm.data + gt_sorted = labels[perm] + grad = lovasz_grad(gt_sorted) + loss = torch.dot(F.relu(errors_sorted), grad) + return loss + + +def lovasz_hinge(logits, + labels, + classes='present', + per_image=False, + class_weight=None, + reduction='mean', + avg_factor=None, + ignore_index=255): + """Binary Lovasz hinge loss. + + Args: + logits (torch.Tensor): [B, H, W], logits at each pixel + (between -infty and +infty). + labels (torch.Tensor): [B, H, W], binary ground truth masks (0 or 1). + classes (str | list[int], optional): Placeholder, to be consistent with + other loss. Default: None. + per_image (bool, optional): If per_image is True, compute the loss per + image instead of per batch. Default: False. + class_weight (list[float], optional): Placeholder, to be consistent + with other loss. Default: None. + reduction (str, optional): The method used to reduce the loss. Options + are "none", "mean" and "sum". This parameter only works when + per_image is True. Default: 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. This parameter only works when per_image is True. + Default: None. + ignore_index (int | None): The label index to be ignored. Default: 255. + + Returns: + torch.Tensor: The calculated loss. + """ + if per_image: + loss = [ + lovasz_hinge_flat(*flatten_binary_logits( + logit.unsqueeze(0), label.unsqueeze(0), ignore_index)) + for logit, label in zip(logits, labels) + ] + loss = weight_reduce_loss( + torch.stack(loss), None, reduction, avg_factor) + else: + loss = lovasz_hinge_flat( + *flatten_binary_logits(logits, labels, ignore_index)) + return loss + + +def lovasz_softmax_flat(probs, labels, classes='present', class_weight=None): + """Multi-class Lovasz-Softmax loss. + + Args: + probs (torch.Tensor): [P, C], class probabilities at each prediction + (between 0 and 1). 
+ labels (torch.Tensor): [P], ground truth labels (between 0 and C - 1). + classes (str | list[int], optional): Classes choosed to calculate loss. + 'all' for all classes, 'present' for classes present in labels, or + a list of classes to average. Default: 'present'. + class_weight (list[float], optional): The weight for each class. + Default: None. + + Returns: + torch.Tensor: The calculated loss. + """ + if probs.numel() == 0: + # only void pixels, the gradients should be 0 + return probs * 0. + C = probs.size(1) + losses = [] + class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes + for c in class_to_sum: + fg = (labels == c).float() # foreground for class c + if (classes == 'present' and fg.sum() == 0): + continue + if C == 1: + if len(classes) > 1: + raise ValueError('Sigmoid output possible only with 1 class') + class_pred = probs[:, 0] + else: + class_pred = probs[:, c] + errors = (fg - class_pred).abs() + errors_sorted, perm = torch.sort(errors, 0, descending=True) + perm = perm.data + fg_sorted = fg[perm] + loss = torch.dot(errors_sorted, lovasz_grad(fg_sorted)) + if class_weight is not None: + loss *= class_weight[c] + losses.append(loss) + return torch.stack(losses).mean() + + +def lovasz_softmax(probs, + labels, + classes='present', + per_image=False, + class_weight=None, + reduction='mean', + avg_factor=None, + ignore_index=255): + """Multi-class Lovasz-Softmax loss. + + Args: + probs (torch.Tensor): [B, C, H, W], class probabilities at each + prediction (between 0 and 1). + labels (torch.Tensor): [B, H, W], ground truth labels (between 0 and + C - 1). + classes (str | list[int], optional): Classes choosed to calculate loss. + 'all' for all classes, 'present' for classes present in labels, or + a list of classes to average. Default: 'present'. + per_image (bool, optional): If per_image is True, compute the loss per + image instead of per batch. Default: False. + class_weight (list[float], optional): The weight for each class. 
+ Default: None. + reduction (str, optional): The method used to reduce the loss. Options + are "none", "mean" and "sum". This parameter only works when + per_image is True. Default: 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. This parameter only works when per_image is True. + Default: None. + ignore_index (int | None): The label index to be ignored. Default: 255. + + Returns: + torch.Tensor: The calculated loss. + """ + + if per_image: + loss = [ + lovasz_softmax_flat( + *flatten_probs( + prob.unsqueeze(0), label.unsqueeze(0), ignore_index), + classes=classes, + class_weight=class_weight) + for prob, label in zip(probs, labels) + ] + loss = weight_reduce_loss( + torch.stack(loss), None, reduction, avg_factor) + else: + loss = lovasz_softmax_flat( + *flatten_probs(probs, labels, ignore_index), + classes=classes, + class_weight=class_weight) + return loss + + +@LOSSES.register_module() +class LovaszLoss(nn.Module): + """LovaszLoss. + + This loss is proposed in `The Lovasz-Softmax loss: A tractable surrogate + for the optimization of the intersection-over-union measure in neural + networks `_. + + Args: + loss_type (str, optional): Binary or multi-class loss. + Default: 'multi_class'. Options are "binary" and "multi_class". + classes (str | list[int], optional): Classes choosed to calculate loss. + 'all' for all classes, 'present' for classes present in labels, or + a list of classes to average. Default: 'present'. + per_image (bool, optional): If per_image is True, compute the loss per + image instead of per batch. Default: False. + reduction (str, optional): The method used to reduce the loss. Options + are "none", "mean" and "sum". This parameter only works when + per_image is True. Default: 'mean'. + class_weight (list[float], optional): The weight for each class. + Default: None. + loss_weight (float, optional): Weight of the loss. Defaults to 1.0. 
+ """ + + def __init__(self, + loss_type='multi_class', + classes='present', + per_image=False, + reduction='mean', + class_weight=None, + loss_weight=1.0): + super(LovaszLoss, self).__init__() + assert loss_type in ('binary', 'multi_class'), "loss_type should be \ + 'binary' or 'multi_class'." + + if loss_type == 'binary': + self.cls_criterion = lovasz_hinge + else: + self.cls_criterion = lovasz_softmax + assert classes in ('all', 'present') or mmcv.is_list_of(classes, int) + if not per_image: + assert reduction == 'none', "reduction should be 'none' when \ + per_image is False." + + self.classes = classes + self.per_image = per_image + self.reduction = reduction + self.loss_weight = loss_weight + self.class_weight = class_weight + + def forward(self, + cls_score, + label, + weight=None, + avg_factor=None, + reduction_override=None, + **kwargs): + """Forward function.""" + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if self.class_weight is not None: + class_weight = cls_score.new_tensor(self.class_weight) + else: + class_weight = None + + # if multi-class loss, transform logits to probs + if self.cls_criterion == lovasz_softmax: + cls_score = F.softmax(cls_score, dim=1) + + loss_cls = self.loss_weight * self.cls_criterion( + cls_score, + label, + self.classes, + self.per_image, + class_weight=class_weight, + reduction=reduction, + avg_factor=avg_factor, + **kwargs) + return loss_cls diff --git a/models/mmseg/models/losses/utils.py b/models/mmseg/models/losses/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a1153fa9f39f1f045def43b52bb55a06f301ff22 --- /dev/null +++ b/models/mmseg/models/losses/utils.py @@ -0,0 +1,101 @@ +import functools + +import torch.nn.functional as F + + +def reduce_loss(loss, reduction): + """Reduce loss as specified. + + Args: + loss (Tensor): Elementwise loss tensor. 
+ reduction (str): Options are "none", "mean" and "sum". + + Return: + Tensor: Reduced loss tensor. + """ + reduction_enum = F._Reduction.get_enum(reduction) + # none: 0, elementwise_mean:1, sum: 2 + if reduction_enum == 0: + return loss + elif reduction_enum == 1: + return loss.mean() + elif reduction_enum == 2: + return loss.sum() + + +def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None): + """Apply element-wise weight and reduce loss. + + Args: + loss (Tensor): Element-wise loss. + weight (Tensor): Element-wise weights. + reduction (str): Same as built-in losses of PyTorch. + avg_factor (float): Avarage factor when computing the mean of losses. + + Returns: + Tensor: Processed loss values. + """ + # if weight is specified, apply element-wise weight + if weight is not None: + assert weight.dim() == loss.dim() + if weight.dim() > 1: + assert weight.size(1) == 1 or weight.size(1) == loss.size(1) + loss = loss * weight + + # if avg_factor is not specified, just reduce the loss + if avg_factor is None: + loss = reduce_loss(loss, reduction) + else: + # if reduction is mean, then average the loss by avg_factor + if reduction == 'mean': + loss = loss.sum() / avg_factor + # if reduction is 'none', then do nothing, otherwise raise an error + elif reduction != 'none': + raise ValueError('avg_factor can not be used with reduction="sum"') + return loss + + +def weighted_loss(loss_func): + """Create a weighted version of a given loss function. + + To use this decorator, the loss function must have the signature like + `loss_func(pred, target, **kwargs)`. The function only needs to compute + element-wise loss without any reduction. This decorator will add weight + and reduction arguments to the function. The decorated function will have + the signature like `loss_func(pred, target, weight=None, reduction='mean', + avg_factor=None, **kwargs)`. 
+ + :Example: + + >>> import torch + >>> @weighted_loss + >>> def l1_loss(pred, target): + >>> return (pred - target).abs() + + >>> pred = torch.Tensor([0, 2, 3]) + >>> target = torch.Tensor([1, 1, 1]) + >>> weight = torch.Tensor([1, 0, 1]) + + >>> l1_loss(pred, target) + tensor(1.3333) + >>> l1_loss(pred, target, weight) + tensor(1.) + >>> l1_loss(pred, target, reduction='none') + tensor([1., 1., 2.]) + >>> l1_loss(pred, target, weight, avg_factor=2) + tensor(1.5000) + """ + + @functools.wraps(loss_func) + def wrapper(pred, + target, + weight=None, + reduction='mean', + avg_factor=None, + **kwargs): + # get element-wise loss + loss = loss_func(pred, target, **kwargs) + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + return wrapper diff --git a/models/mmseg/models/sam/__init__.py b/models/mmseg/models/sam/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..38e906243d898d7fc071c0fe218338c5cace3ea1 --- /dev/null +++ b/models/mmseg/models/sam/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +from .sam import Sam +from .image_encoder import ImageEncoderViT +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder +from .transformer import TwoWayTransformer diff --git a/models/mmseg/models/sam/__pycache__/__init__.cpython-37.pyc b/models/mmseg/models/sam/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81bd65da289d022dcdd3dc73eb8336c187d5f543 Binary files /dev/null and b/models/mmseg/models/sam/__pycache__/__init__.cpython-37.pyc differ diff --git a/models/mmseg/models/sam/__pycache__/common.cpython-37.pyc b/models/mmseg/models/sam/__pycache__/common.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..abd3061d211171df2442c3abe210586f2ce7e397 Binary files /dev/null and b/models/mmseg/models/sam/__pycache__/common.cpython-37.pyc differ diff --git a/models/mmseg/models/sam/__pycache__/image_encoder.cpython-37.pyc b/models/mmseg/models/sam/__pycache__/image_encoder.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40a5d3c7188811e161b7180bff330e784163c8a3 Binary files /dev/null and b/models/mmseg/models/sam/__pycache__/image_encoder.cpython-37.pyc differ diff --git a/models/mmseg/models/sam/__pycache__/mask_decoder.cpython-37.pyc b/models/mmseg/models/sam/__pycache__/mask_decoder.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4fbffdcefed6828ff56a77be54c54949c1ae19b Binary files /dev/null and b/models/mmseg/models/sam/__pycache__/mask_decoder.cpython-37.pyc differ diff --git a/models/mmseg/models/sam/__pycache__/prompt_encoder.cpython-37.pyc b/models/mmseg/models/sam/__pycache__/prompt_encoder.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9cede2ea147b40796af0a631bdce07b772380cd7 Binary files /dev/null and b/models/mmseg/models/sam/__pycache__/prompt_encoder.cpython-37.pyc differ diff --git 
a/models/mmseg/models/sam/__pycache__/sam.cpython-37.pyc b/models/mmseg/models/sam/__pycache__/sam.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a82223b92574937afb512c277bf7ee1fcb68a1d Binary files /dev/null and b/models/mmseg/models/sam/__pycache__/sam.cpython-37.pyc differ diff --git a/models/mmseg/models/sam/__pycache__/transformer.cpython-37.pyc b/models/mmseg/models/sam/__pycache__/transformer.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f73d8d6efcffb14dae45882923f89153bc98556 Binary files /dev/null and b/models/mmseg/models/sam/__pycache__/transformer.cpython-37.pyc differ diff --git a/models/mmseg/models/sam/common.py b/models/mmseg/models/sam/common.py new file mode 100644 index 0000000000000000000000000000000000000000..2bf15236a3eb24d8526073bc4fa2b274cccb3f96 --- /dev/null +++ b/models/mmseg/models/sam/common.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch +import torch.nn as nn + +from typing import Type + + +class MLPBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + mlp_dim: int, + act: Type[nn.Module] = nn.GELU, + ) -> None: + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x diff --git a/models/mmseg/models/sam/image_encoder.py b/models/mmseg/models/sam/image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..d7dad6d90429325b64a251991f9581ef71b97200 --- /dev/null +++ b/models/mmseg/models/sam/image_encoder.py @@ -0,0 +1,644 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 

import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Optional, Tuple, Type

from .common import LayerNorm2d, MLPBlock
import math
import warnings
from itertools import repeat

# Compatibility shim: torch < 1.8 shipped the collections ABCs under torch._six.
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
if TORCH_MAJOR == 1 and TORCH_MINOR < 8:
    from torch._six import container_abcs
else:
    import collections.abc as container_abcs

# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
class ImageEncoderViT(nn.Module):
    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        out_chans: int = 256,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_abs_pos: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        global_attn_indexes: Tuple[int, ...] = (),
    ) -> None:
        """
        Args:
            img_size (int): Input image size.
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Depth of ViT.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks.
            global_attn_indexes (list): Indexes for blocks using global attention.
        """
        super().__init__()
        self.img_size = img_size
        self.embed_dim = embed_dim
        self.depth = depth

        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
        )  # split the image into patch tokens

        self.pos_embed: Optional[nn.Parameter] = None
        if use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
            self.pos_embed = nn.Parameter(
                torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
            )

        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                norm_layer=norm_layer,
                act_layer=act_layer,
                use_rel_pos=use_rel_pos,
                rel_pos_zero_init=rel_pos_zero_init,
                # Blocks listed in global_attn_indexes use global attention
                # (window_size 0); all others use windowed attention.
                window_size=window_size if i not in global_attn_indexes else 0,
                input_size=(img_size // patch_size, img_size // patch_size),
            )  # transformer block with window attention
            self.blocks.append(block)

        self.neck = nn.Sequential(
            nn.Conv2d(
                embed_dim,
                out_chans,
                kernel_size=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
            nn.Conv2d(
                out_chans,
                out_chans,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
        )  # 1x1 conv + LN + 3x3 conv + LN

        # Prompt-generator configuration (SAM-adapter additions). These are
        # hard-coded here rather than exposed as constructor arguments.
        self.scale_factor = 32
        self.prompt_type = 'highpass'
        self.tuning_stage = 1234
        self.input_type = 'fft'
        self.freq_nums = 0.25
        self.handcrafted_tune = True
        self.embedding_tune = True
        self.adaptor = 'adaptor'
        self.prompt_generator = PromptGenerator(self.scale_factor, self.prompt_type, self.embed_dim,
                                                self.tuning_stage, self.depth,
                                                self.input_type, self.freq_nums,
                                                self.handcrafted_tune, self.embedding_tune, self.adaptor,
                                                img_size, patch_size)
        self.num_stages = self.depth
        self.out_indices = tuple(range(self.num_stages))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        inp = x
        x = self.patch_embed(x)  # image -> patch tokens

        # Build the per-block prompts: a down-projected copy of the patch
        # embedding plus a handcrafted (FFT high-pass) feature of the raw input.
        embedding_feature = self.prompt_generator.init_embeddings(x)  # N x (H*W) x (C // scale_factor)
        handcrafted_feature = self.prompt_generator.init_handcrafted(inp)
        prompt = self.prompt_generator.get_prompt(handcrafted_feature, embedding_feature)
        if self.pos_embed is not None:  # absolute positional embedding
            x = x + self.pos_embed

        B, H, W = x.shape[0], x.shape[1], x.shape[2]
        outs = []
        for i, blk in enumerate(self.blocks):
            x = prompt[i].reshape(B, H, W, -1) + x  # add the prompt tokens to x before each block
            x = blk(x)  # transformer block
            if i in self.out_indices:
                # NOTE(review): `outs` collects multi-level outputs but is
                # never returned or read — confirm whether multi-scale
                # features were meant to be exposed.
                outs.append(x)  # multi-level outputs

        x = self.neck(x.permute(0, 3, 1, 2))  # to N*C*H*W, then 1x1 conv + LN + 3x3 conv + LN

        return x

def to_2tuple(x):
    # Return iterables unchanged, otherwise repeat the scalar twice.
    # NOTE(review): strings are iterable and pass through untouched — this
    # matches the timm helper the function is copied from.
    if isinstance(x, container_abcs.Iterable):
        return x
    return tuple(repeat(x, 2))

def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.
    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)

def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


class PromptGenerator(nn.Module):
    def __init__(self, scale_factor, prompt_type, embed_dim, tuning_stage, depth, input_type,
                 freq_nums, handcrafted_tune, embedding_tune, adaptor, img_size, patch_size):
        """Generate per-block prompt tokens for the ViT encoder.

        Combines a handcrafted frequency-domain feature of the input image
        (see `init_handcrafted`) with a down-projected copy of the patch
        embedding (`init_embeddings`), expanding the sum through one
        lightweight MLP per block depth plus a shared up-projection
        (`get_prompt`).
        """
        super(PromptGenerator, self).__init__()
        self.scale_factor = scale_factor
        self.prompt_type = prompt_type
        self.embed_dim = embed_dim
        self.input_type = input_type
        self.freq_nums = freq_nums
        self.tuning_stage = tuning_stage
        self.depth = depth
        self.handcrafted_tune = handcrafted_tune
        self.embedding_tune = embedding_tune
        self.adaptor = adaptor

        # Shared up-projection back to embed_dim, and down-projection of the
        # patch embedding to embed_dim // scale_factor.
        self.shared_mlp = nn.Linear(self.embed_dim//self.scale_factor, self.embed_dim)
        self.embedding_generator = nn.Linear(self.embed_dim, self.embed_dim//self.scale_factor)
        for i in range(self.depth):
            lightweight_mlp = nn.Sequential(
                nn.Linear(self.embed_dim//self.scale_factor, self.embed_dim//self.scale_factor),
                nn.GELU()
            )
            setattr(self, 'lightweight_mlp_{}'.format(str(i)), lightweight_mlp)

        self.prompt_generator = PatchEmbed2(img_size=img_size,
                                            patch_size=patch_size, in_chans=3,
                                            embed_dim=self.embed_dim//self.scale_factor)  # patch_size-sized convolution

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal init for Linear, unit-affine for LayerNorm,
        # He-style fan-out init for Conv2d.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def init_embeddings(self, x):
        # x arrives as (N, H, W, C) from PatchEmbed; flatten to a token
        # sequence and project down to embed_dim // scale_factor.
        N, C, H, W = x.permute(0, 3, 1, 2).shape
        x = x.reshape(N, C, H*W).permute(0, 2, 1)
        return self.embedding_generator(x)  # dimensionality reduction

    def init_handcrafted(self, x):
        x = self.fft(x, self.freq_nums)  # FFT-based high-pass filtering of the raw image
        return self.prompt_generator(x)  # patch-size convolution, N*C*H*W

    def get_prompt(self, handcrafted_feature, embedding_feature):
        N, C, H, W = handcrafted_feature.shape
        handcrafted_feature = handcrafted_feature.view(N, C, H*W).permute(0, 2, 1)
        prompts = []
        for i in range(self.depth):
            lightweight_mlp = getattr(self, 'lightweight_mlp_{}'.format(str(i)))
            # prompt = proj_prompt(prompt)
            prompt = lightweight_mlp(handcrafted_feature + embedding_feature)  # add the two features, then FC + GELU
            prompts.append(self.shared_mlp(prompt))  # project back up to embed_dim
        return prompts

    def forward(self, x):
        # NOTE(review): this method references attributes that __init__ never
        # assigns (`lap_pyramid`, `prompt`, `mode`, `proj`, `proj_{i}`,
        # `proj_prompt_{i}`, `proj_token`), so calling it raises
        # AttributeError. The encoder only calls init_embeddings /
        # init_handcrafted / get_prompt — confirm whether forward is dead code.
        if self.input_type == 'laplacian':
            pyr_A = self.lap_pyramid.pyramid_decom(img=x, num=self.freq_nums)
            x = pyr_A[:-1]
            laplacian = x[0]
            for x_i in x[1:]:
                x_i = F.interpolate(x_i, size=(laplacian.size(2), laplacian.size(3)), mode='bilinear', align_corners=True)
                laplacian = torch.cat([laplacian, x_i], dim=1)
            x = laplacian
        elif self.input_type == 'fft':
            x = self.fft(x, self.freq_nums)
        elif self.input_type == 'all':
            x = self.prompt.unsqueeze(0).repeat(x.shape[0], 1, 1, 1)

        # get prompting
        prompt = self.prompt_generator(x)  # patch_size-sized convolution over the image

        if self.mode == 'input':
            prompt = self.proj(prompt)
            return prompt
        elif self.mode == 'stack':
            prompts = []
            for i in range(self.depth):
                proj = getattr(self, 'proj_{}'.format(str(i)))
                prompts.append(proj(prompt))
            return prompts
        elif self.mode == 'hierarchical':
            prompts = []
            for i in range(self.depth):
                proj_prompt = getattr(self, 'proj_prompt_{}'.format(str(i)))
                prompt = proj_prompt(prompt)
                prompts.append(self.proj_token(prompt))
            return prompts

    def fft(self, x, rate):  # rate is self.freq_nums (0.25) at the call sites above
        # the smaller rate, the smoother; the larger rate, the darker
        # rate = 4, 8, 16, 32
        mask = torch.zeros(x.shape).to(x.device)
        w, h = x.shape[-2:]
        line = int((w * h * rate) ** .5 // 2)
        mask[:, :, w//2-line:w//2+line, h//2-line:h//2+line] = 1

        fft = torch.fft.fftshift(torch.fft.fft2(x, norm="forward"))  # 2-D FFT, zero frequency shifted to the center
        # mask[fft.float() > self.freq_nums] = 1
        # high pass: 1-mask, low pass: mask
        fft = fft * (1 - mask)
        # fft = fft * mask
        fr = fft.real
        fi = fft.imag

        fft_hires = torch.fft.ifftshift(torch.complex(fr, fi))
        inv = torch.fft.ifft2(fft_hires, norm="forward").real

        inv = torch.abs(inv)

        return inv

class PatchEmbed2(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * \
            (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ + # x = F.interpolate(x, size=2*x.shape[-1], mode='bilinear', align_corners=True) + x = self.proj(x) + return x + + +class Block(nn.Module): + """Transformer blocks with support of window attention and residual propagation blocks""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then + use global attention. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. 
+ """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) + + self.window_size = window_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + + return x + + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. 
+ """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert ( + input_size is not None + ), "Input size must be provided if using relative positional encoding." + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) + + attn = attn.softmax(dim=-1) + x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + x = self.proj(x) + + return x + + +def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. 
+ (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition( + windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] +) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. 
+ rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. 
+ """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = ( + attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + ).view(B, q_h * q_w, k_h * k_w) + + return attn + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): embed_dim (int): Patch embedding dimension. + """ + super().__init__() + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/models/mmseg/models/sam/mask_decoder.py b/models/mmseg/models/sam/mask_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..a41f902dc6e451e6d94a8a721e5d7882ec0debe9 --- /dev/null +++ b/models/mmseg/models/sam/mask_decoder.py @@ -0,0 +1,189 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch +from torch import nn +from torch.nn import functional as F + +from typing import List, Tuple, Type + +from .common import LayerNorm2d + + +class MaskDecoder(nn.Module): + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + transformer architecture. + + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + activation(), + ) + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for i in range(self.num_mask_tokens) + ] + ) + + self.iou_prediction_head = MLP( + transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth + ) + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + 
sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predict masks given image and prompt embeddings.

        Arguments:
          image_embeddings (torch.Tensor): the embeddings from the image encoder
          image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
          sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
          dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
          multimask_output (bool): Whether to return multiple masks or a single
            mask. NOTE(review): this flag is currently ignored — the mask
            slicing below is commented out, so all mask tokens are always
            returned. Confirm this is intentional before relying on it.

        Returns:
          torch.Tensor: batched predicted masks
          torch.Tensor: batched predictions of mask quality
        """
        masks, iou_pred = self.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
        )

        # Select the correct mask or masks for output
        # False
        # if multimask_output:
        #     mask_slice = slice(1, None)
        # else:
        #     mask_slice = slice(0, 1)
        # print('****before slice:::',masks.shape)
        # masks = masks[:, mask_slice, :, :]

        # iou_pred = iou_pred[:, mask_slice]
        # print('****behind slice:::', masks.shape)
        # Prepare output

        # masks = masks[:, :6]
        # iou_pred = iou_pred[:, :6]
        return masks, iou_pred

    def predict_masks(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predicts masks.
See 'forward' for more details.""" + # Concatenate output tokens + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) + output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + #修改 + if image_embeddings.shape[0] != tokens.shape[0]: + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + # print('----!=----', image_embeddings.shape, tokens.shape) + else: + src = image_embeddings + # src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)#重复张量元素 + src = src + dense_prompt_embeddings + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, 0, :] + mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + upscaled_embedding = self.output_upscaling(src) + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + # print('**mask::::', masks.shape) + #b*4*1024*1024 + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + + return masks, iou_pred + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + 
self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x diff --git a/models/mmseg/models/sam/prompt_encoder.py b/models/mmseg/models/sam/prompt_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..6023be7dce2f2d0a343c7701b8a4b53c00c19d89 --- /dev/null +++ b/models/mmseg/models/sam/prompt_encoder.py @@ -0,0 +1,254 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch import nn + +from typing import Any, Optional, Tuple, Type + +from .common import LayerNorm2d + + +class PromptEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + mask_in_chans: int, + activation: Type[nn.Module] = nn.GELU, + ) -> None: + """ + Encodes prompts for input to SAM's mask decoder. + + Arguments: + embed_dim (int): The prompts' embedding dimension + image_embedding_size (tuple(int, int)): The spatial size of the + image embedding, as (H, W). + input_image_size (int): The padded size of the image as input + to the image encoder, as (H, W). + mask_in_chans (int): The number of hidden channels used for + encoding input masks. + activation (nn.Module): The activation to use when encoding + input masks. 
+ """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + + self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners + point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] + self.point_embeddings = nn.ModuleList(point_embeddings) + self.not_a_point_embed = nn.Embedding(1, embed_dim) + + self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) + self.mask_downscaling = nn.Sequential( + nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans // 4), + activation(), + nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans), + activation(), + nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), + ) + self.scatter_downscaling = nn.Sequential( + nn.Conv2d(3, mask_in_chans // 4, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans // 4), + activation(), + nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans), + activation(), + nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), + ) + self.no_mask_embed = nn.Embedding(1, embed_dim) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. 
+ + Returns: + torch.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + def _embed_points( + self, + points: torch.Tensor, + labels: torch.Tensor, + pad: bool, + ) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) + padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) + points = torch.cat([points, padding_point], dim=1) + labels = torch.cat([labels, padding_label], dim=1) + point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) + point_embedding[labels == -1] = 0.0 + point_embedding[labels == -1] += self.not_a_point_embed.weight + point_embedding[labels == 0] += self.point_embeddings[0].weight + point_embedding[labels == 1] += self.point_embeddings[1].weight + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + coords = boxes.reshape(-1, 2, 2) + # corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) + # corner_embedding[:, 1, :] += self.point_embeddings[3].weight + # return corner_embedding + + coords_list_ = [] + for i_ in range(boxes.shape[0]): + one_coords = coords[i_] + one_coords = one_coords.reshape(1, 2, 2) + one_corner_embedding = self.pe_layer.forward_with_coords(one_coords, self.input_image_size) + one_corner_embedding[:, 0, :] += self.point_embeddings[2].weight + one_corner_embedding[:, 1, :] += self.point_embeddings[3].weight + coords_list_.append(one_corner_embedding) + # print(one_corner_embedding.shape) + + coords_embedding = torch.cat(coords_list_, dim=0) + coords_embedding = coords_embedding.reshape(1, -1, one_corner_embedding.shape[-1]) + return coords_embedding + + def _embed_masks(self, masks: torch.Tensor) -> 
torch.Tensor: + """Embeds mask inputs.""" + mask_embedding = self.mask_downscaling(masks) + return mask_embedding + + def _embed_scatter(self, scatter: torch.Tensor) -> torch.Tensor: + scatter_embedding = self.scatter_downscaling(scatter) + return scatter_embedding + + def _get_batch_size( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> int: + """ + Gets the batch size of the output given the batch size of the input prompts. + """ + if points is not None: + return points[0].shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + else: + return 1 + + def _get_device(self) -> torch.device: + return self.point_embeddings[0].weight.device + + def forward( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + scatter: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense + embeddings. + + Arguments: + points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates + and labels to embed. + boxes (torch.Tensor or none): boxes to embed + masks (torch.Tensor or none): masks to embed + + Returns: + torch.Tensor: sparse embeddings for the points and boxes, with shape + BxNx(embed_dim), where N is determined by the number of input points + and boxes. 
+ torch.Tensor: dense embeddings for the masks, in the shape + Bx(embed_dim)x(embed_H)x(embed_W) + """ + # bs = self._get_batch_size(points, boxes, masks) + bs = 1 + # print('bs::::', bs, boxes) + sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) #bs*0*256 + if points is not None: + coords, labels = points + point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + if boxes is not None: + if boxes.shape[0] != 0: + box_embeddings = self._embed_boxes(boxes) + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) + else: + box_embeddings = None + + # print('spa::::', sparse_embeddings.shape) + + if masks is not None: + dense_embeddings = self._embed_masks(masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + if scatter is not None: + scatter_embeddings = self._embed_scatter(scatter) + else: + scatter_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + + return sparse_embeddings, dense_embeddings, scatter_embeddings + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer( + "positional_encoding_gaussian_matrix", + scale * torch.randn((2, num_pos_feats)), + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... 
x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + def forward_with_coords( + self, coords_input: torch.Tensor, image_size: Tuple[int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords.to(torch.float)) # B x N x C diff --git a/models/mmseg/models/sam/sam.py b/models/mmseg/models/sam/sam.py new file mode 100644 index 0000000000000000000000000000000000000000..8074cff6b40addc6b66f7ab4962218eef20da13c --- /dev/null +++ b/models/mmseg/models/sam/sam.py @@ -0,0 +1,174 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
# ---- models/mmseg/models/sam/sam.py (reconstructed from line-mangled patch) ----
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn
from torch.nn import functional as F

from typing import Any, Dict, List, Tuple

from .image_encoder import ImageEncoderViT
from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder


class Sam(nn.Module):
    """SAM end-to-end model: image encoder + prompt encoder + mask decoder."""

    mask_threshold: float = 0.0
    image_format: str = "RGB"

    def __init__(
        self,
        image_encoder: ImageEncoderViT,
        prompt_encoder: PromptEncoder,
        mask_decoder: MaskDecoder,
        pixel_mean: List[float] = [123.675, 116.28, 103.53],
        pixel_std: List[float] = [58.395, 57.12, 57.375],
    ) -> None:
        """
        SAM predicts object masks from an image and input prompts.

        Arguments:
          image_encoder (ImageEncoderViT): backbone that encodes the image
            into embeddings used for efficient mask prediction.
          prompt_encoder (PromptEncoder): encodes point/box/mask/scatter
            prompts.
          mask_decoder (MaskDecoder): predicts masks from the image
            embeddings and the encoded prompts.
          pixel_mean (list(float)): per-channel mean for input normalization.
          pixel_std (list(float)): per-channel std for input normalization.
        """
        super().__init__()
        self.image_encoder = image_encoder
        self.prompt_encoder = prompt_encoder
        self.mask_decoder = mask_decoder
        # Non-persistent buffers: follow .to(device) with the module but are
        # not written to checkpoints.
        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)

    @property
    def device(self) -> Any:
        return self.pixel_mean.device

    @torch.no_grad()
    def forward(
        self,
        batched_input: List[Dict[str, Any]],
        multimask_output: bool,
    ) -> List[Dict[str, torch.Tensor]]:
        """
        Predicts masks end-to-end from provided images and prompts.
        If prompts are not known in advance, SamPredictor is recommended
        over calling the model directly.

        Arguments:
          batched_input (list(dict)): one dict per input image with keys
            'image' (3xHxW tensor, already transformed for the model),
            'original_size' ((H, W) before transformation) and, optionally,
            'point_coords' (BxNx2), 'point_labels' (BxN), 'boxes' (Bx4),
            'mask_inputs' (Bx1xHxW) and 'scatter_inputs' (prompt for the
            scatter branch of this PromptEncoder variant).
          multimask_output (bool): whether the model should predict multiple
            disambiguating masks or a single mask.

        Returns:
          list(dict): per input image: 'masks' (BxCxHxW bool, at the
          original image size), 'iou_predictions' (BxC) and
          'low_res_logits' (BxCxHxW, H=W=256, reusable as mask input).
        """
        input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
        image_embeddings = self.image_encoder(input_images)

        outputs = []
        for image_record, curr_embedding in zip(batched_input, image_embeddings):
            if "point_coords" in image_record:
                points = (image_record["point_coords"], image_record["point_labels"])
            else:
                points = None
            # FIX: this repository's PromptEncoder takes a 4th `scatter`
            # prompt and returns THREE embeddings (sparse, dense, scatter);
            # the previous 2-tuple unpacking without `scatter=` raised at
            # runtime (missing argument / too many values to unpack).
            sparse_embeddings, dense_embeddings, scatter_embeddings = self.prompt_encoder(
                points=points,
                boxes=image_record.get("boxes", None),
                masks=image_record.get("mask_inputs", None),
                scatter=image_record.get("scatter_inputs", None),
            )
            # NOTE(review): scatter_embeddings is not consumed by this
            # decoder call — confirm whether this MaskDecoder variant
            # expects it as an extra input.
            low_res_masks, iou_predictions = self.mask_decoder(
                image_embeddings=curr_embedding.unsqueeze(0),
                image_pe=self.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )
            masks = self.postprocess_masks(
                low_res_masks,
                input_size=image_record["image"].shape[-2:],
                original_size=image_record["original_size"],
            )
            masks = masks > self.mask_threshold
            outputs.append(
                {
                    "masks": masks,
                    "iou_predictions": iou_predictions,
                    "low_res_logits": low_res_masks,
                }
            )
        return outputs

    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """
        Remove padding and upscale masks to the original image size.

        Arguments:
          masks (torch.Tensor): BxCxHxW low-resolution decoder masks.
          input_size (tuple(int, int)): (H, W) of the image fed to the
            model; used to crop away the square padding.
          original_size (tuple(int, int)): (H, W) of the image before any
            resizing for model input.

        Returns:
          torch.Tensor: BxCxHxW masks at original_size.
        """
        # Upscale to the padded square model resolution, crop the padding,
        # then resize to the original image size.
        masks = F.interpolate(
            masks,
            (self.image_encoder.img_size, self.image_encoder.img_size),
            mode="bilinear",
            align_corners=False,
        )
        masks = masks[..., : input_size[0], : input_size[1]]
        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
        return masks

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize pixel values and zero-pad to a square model input."""
        x = (x - self.pixel_mean) / self.pixel_std

        # Pad bottom/right so H = W = image_encoder.img_size.
        h, w = x.shape[-2:]
        padh = self.image_encoder.img_size - h
        padw = self.image_encoder.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x


# ---- models/mmseg/models/sam/transformer.py (reconstructed from line-mangled patch) ----
import math
from typing import Type

from torch import Tensor

from .common import MLPBlock


class TwoWayTransformer(nn.Module):
    """Transformer decoder attending to an image using positional queries."""

    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
    ) -> None:
        """
        Args:
          depth (int): number of layers in the transformer.
          embedding_dim (int): channel dimension of the input embeddings.
          num_heads (int): number of attention heads; must divide the
            downsampled internal dimension.
          mlp_dim (int): internal channel dimension of the MLP block.
          activation (nn.Module): activation used in the MLP block.
        """
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.layers = nn.ModuleList()

        for i in range(depth):
            self.layers.append(
                TwoWayAttentionBlock(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    activation=activation,
                    attention_downsample_rate=attention_downsample_rate,
                    skip_first_layer_pe=(i == 0),  # first layer: raw queries
                )
            )

        self.final_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(
        self,
        image_embedding: Tensor,
        image_pe: Tensor,
        point_embedding: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
          image_embedding (torch.Tensor): B x embedding_dim x h x w image
            to attend to.
          image_pe (torch.Tensor): positional encoding to add to the image;
            same shape as image_embedding.
          point_embedding (torch.Tensor): B x N_points x embedding_dim
            embedding to add to the query points.

        Returns:
          (processed point_embedding, processed image_embedding)
        """
        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
        bs, c, h, w = image_embedding.shape
        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
        image_pe = image_pe.flatten(2).permute(0, 2, 1)

        # Prepare queries
        queries = point_embedding
        keys = image_embedding

        # Apply transformer blocks and final layernorm
        for layer in self.layers:
            queries, keys = layer(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,
                key_pe=image_pe,
            )

        # Apply the final attention layer from the points to the image
        q = queries + point_embedding
        k = keys + image_pe
        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm_final_attn(queries)

        return queries, keys


class TwoWayAttentionBlock(nn.Module):
    """
    A transformer block with four layers: (1) self-attention of sparse
    inputs, (2) cross attention of sparse inputs to dense inputs, (3) MLP
    block on sparse inputs, and (4) cross attention of dense inputs to
    sparse inputs.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        """
        Arguments:
          embedding_dim (int): channel dimension of the embeddings.
          num_heads (int): number of heads in the attention layers.
          mlp_dim (int): hidden dimension of the MLP block.
          activation (nn.Module): activation of the MLP block.
          skip_first_layer_pe (bool): skip the PE on the first layer.
        """
        super().__init__()
        self.self_attn = Attention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        self.cross_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

        self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
        self.norm3 = nn.LayerNorm(embedding_dim)

        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
    ) -> Tuple[Tensor, Tensor]:
        # (1) Self attention on the sparse tokens.
        if self.skip_first_layer_pe:
            queries = self.self_attn(q=queries, k=queries, v=queries)
        else:
            q = queries + query_pe
            attn_out = self.self_attn(q=q, k=q, v=queries)
            queries = queries + attn_out
        queries = self.norm1(queries)

        # (2) Cross attention: tokens attending to image embedding.
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm2(queries)

        # (3) MLP on the sparse tokens.
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        queries = self.norm3(queries)

        # (4) Cross attention: image embedding attending to tokens.
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
        keys = keys + attn_out
        keys = self.norm4(keys)

        return queries, keys


class Attention(nn.Module):
    """
    An attention layer that allows for downscaling the size of the
    embedding after projection to queries, keys, and values.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        downsample_rate: int = 1,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.internal_dim = embedding_dim // downsample_rate
        self.num_heads = num_heads
        assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."

        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
        b, n, c = x.shape
        x = x.reshape(b, n, num_heads, c // num_heads)
        return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head

    def _recombine_heads(self, x: Tensor) -> Tensor:
        b, n_heads, n_tokens, c_per_head = x.shape
        x = x.transpose(1, 2)
        return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        # Input projections
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # Separate into heads
        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        # Scaled dot-product attention
        _, _, _, c_per_head = q.shape
        attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
        attn = attn / math.sqrt(c_per_head)
        attn = torch.softmax(attn, dim=-1)

        # Get output
        out = attn @ v
        out = self._recombine_heads(out)
        out = self.out_proj(out)

        return out
# ---- models/mmseg/models/utils/__init__.py (reconstructed from line-mangled patch) ----
from .inverted_residual import InvertedResidual, InvertedResidualV3
from .make_divisible import make_divisible
from .res_layer import ResLayer
from .self_attention_block import SelfAttentionBlock
from .up_conv_block import UpConvBlock

__all__ = [
    'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',
    'UpConvBlock', 'InvertedResidualV3'
]


# ---- models/mmseg/models/utils/drop.py (reconstructed from line-mangled patch) ----
""" DropBlock, DropPath
PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers.
Papers:
DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890)
Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382)
Code:
DropBlock impl inspired by two Tensorflow impl that I liked:
 - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74
 - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


def drop_block_2d(
        x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0,
        with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
    """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
    DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
    runs with success, but needs further validation and possibly optimization for lower runtime impact.
    """
    B, C, H, W = x.shape
    total_size = W * H
    clipped_block_size = min(block_size, min(W, H))
    # seed_drop_rate, the gamma parameter
    gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
            (W - block_size + 1) * (H - block_size + 1))

    # Forces the block to be inside the feature map.
    # NOTE(review): the (W, H) meshgrid is reshaped to (1, 1, H, W) below; for
    # non-square inputs this transposes the validity mask. It matches the
    # upstream timm code, but confirm inputs are square before relying on it.
    w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device))
    valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \
                  ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))
    valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype)

    if batchwise:
        # one mask for whole batch, quite a bit faster
        uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device)
    else:
        uniform_noise = torch.rand_like(x)
    block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
    # Max-pool of the negated mask grows each dropped seed into a full block.
    block_mask = -F.max_pool2d(
        -block_mask,
        kernel_size=clipped_block_size,  # block_size,
        stride=1,
        padding=clipped_block_size // 2)

    if with_noise:
        normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
        if inplace:
            x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
        else:
            x = x * block_mask + normal_noise * (1 - block_mask)
    else:
        normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype)
        if inplace:
            x.mul_(block_mask * normalize_scale)
        else:
            x = x * block_mask * normalize_scale
    return x


def drop_block_fast_2d(
        x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7,
        gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
    """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
    DropBlock with an experimental gaussian noise option. Simplified from above without concern for valid
    block mask at edges.
    """
    B, C, H, W = x.shape
    total_size = W * H
    clipped_block_size = min(block_size, min(W, H))
    gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
            (W - block_size + 1) * (H - block_size + 1))

    if batchwise:
        # one mask for whole batch, quite a bit faster
        block_mask = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) < gamma
    else:
        # mask per batch element
        block_mask = torch.rand_like(x) < gamma
    block_mask = F.max_pool2d(
        block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2)

    if with_noise:
        normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
        if inplace:
            x.mul_(1. - block_mask).add_(normal_noise * block_mask)
        else:
            x = x * (1. - block_mask) + normal_noise * block_mask
    else:
        block_mask = 1 - block_mask
        normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(dtype=x.dtype)
        if inplace:
            x.mul_(block_mask * normalize_scale)
        else:
            x = x * block_mask * normalize_scale
    return x


class DropBlock2d(nn.Module):
    """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
    """
    def __init__(self,
                 drop_prob=0.1,
                 block_size=7,
                 gamma_scale=1.0,
                 with_noise=False,
                 inplace=False,
                 batchwise=False,
                 fast=True):
        super(DropBlock2d, self).__init__()
        self.drop_prob = drop_prob
        self.gamma_scale = gamma_scale
        self.block_size = block_size
        self.with_noise = with_noise
        self.inplace = inplace
        self.batchwise = batchwise
        self.fast = fast  # FIXME finish comparisons of fast vs not

    def forward(self, x):
        # Identity at eval time or when drop_prob is falsy (0 or None).
        if not self.training or not self.drop_prob:
            return x
        if self.fast:
            return drop_block_fast_2d(
                x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
        else:
            return drop_block_2d(
                x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # FIX: drop_prob defaults to None, and drop_path computes
        # `1 - drop_prob`, so a default-constructed DropPath raised a
        # TypeError in training mode. Treat None as 0 (identity), mirroring
        # DropBlock2d's `not self.drop_prob` guard.
        return drop_path(x, self.drop_prob or 0., self.training)


# ---- models/mmseg/models/utils/inverted_residual.py (reconstructed from line-mangled patch) ----
from mmcv.cnn import ConvModule
from torch.utils import checkpoint as cp

from .se_layer import SELayer


class InvertedResidual(nn.Module):
    """InvertedResidual block for MobileNetV2.

    Args:
        in_channels (int): The input channels of the InvertedResidual block.
        out_channels (int): The output channels of the InvertedResidual block.
        stride (int): Stride of the middle (first) 3x3 convolution.
        expand_ratio (int): Adjusts number of channels of the hidden layer
            in InvertedResidual by this amount.
        dilation (int): Dilation rate of depthwise conv. Default: 1
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU6').
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.

    Returns:
        Tensor: The output tensor.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 expand_ratio,
                 dilation=1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU6'),
                 with_cp=False):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2], f'stride must in [1, 2]. ' \
            f'But received {stride}.'
        self.with_cp = with_cp
        # Residual connection only when shape is preserved.
        self.use_res_connect = self.stride == 1 and in_channels == out_channels
        hidden_dim = int(round(in_channels * expand_ratio))

        layers = []
        if expand_ratio != 1:
            # 1x1 expansion conv.
            layers.append(
                ConvModule(
                    in_channels=in_channels,
                    out_channels=hidden_dim,
                    kernel_size=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))
        layers.extend([
            # 3x3 depthwise conv.
            ConvModule(
                in_channels=hidden_dim,
                out_channels=hidden_dim,
                kernel_size=3,
                stride=stride,
                padding=dilation,
                dilation=dilation,
                groups=hidden_dim,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg),
            # 1x1 linear projection (no activation).
            ConvModule(
                in_channels=hidden_dim,
                out_channels=out_channels,
                kernel_size=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None)
        ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):

        def _inner_forward(x):
            if self.use_res_connect:
                return x + self.conv(x)
            else:
                return self.conv(x)

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        return out


class InvertedResidualV3(nn.Module):
    """Inverted Residual Block for MobileNetV3.

    Args:
        in_channels (int): The input channels of this Module.
        out_channels (int): The output channels of this Module.
        mid_channels (int): The input channels of the depthwise convolution.
        kernel_size (int): The kernal size of the depthwise convolution.
            Default: 3.
        stride (int): The stride of the depthwise convolution. Default: 1.
        se_cfg (dict): Config dict for se layer. Defaul: None, which means no
            se layer.
        with_expand_conv (bool): Use expand conv or not. If set False,
            mid_channels must be the same with in_channels. Default: True.
        conv_cfg (dict): Config dict for convolution layer. Default: None,
            which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.

    Returns:
        Tensor: The output tensor.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels,
                 kernel_size=3,
                 stride=1,
                 se_cfg=None,
                 with_expand_conv=True,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 with_cp=False):
        super(InvertedResidualV3, self).__init__()
        self.with_res_shortcut = (stride == 1 and in_channels == out_channels)
        assert stride in [1, 2]
        self.with_cp = with_cp
        self.with_se = se_cfg is not None
        self.with_expand_conv = with_expand_conv

        if self.with_se:
            assert isinstance(se_cfg, dict)
        if not self.with_expand_conv:
            assert mid_channels == in_channels

        if self.with_expand_conv:
            self.expand_conv = ConvModule(
                in_channels=in_channels,
                out_channels=mid_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        self.depthwise_conv = ConvModule(
            in_channels=mid_channels,
            out_channels=mid_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            groups=mid_channels,
            # Adaptive padding keeps 'same'-style output when striding.
            conv_cfg=dict(
                type='Conv2dAdaptivePadding') if stride == 2 else conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

        if self.with_se:
            self.se = SELayer(**se_cfg)

        self.linear_conv = ConvModule(
            in_channels=mid_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)

    def forward(self, x):

        def _inner_forward(x):
            out = x

            if self.with_expand_conv:
                out = self.expand_conv(out)

            out = self.depthwise_conv(out)

            if self.with_se:
                out = self.se(out)

            out = self.linear_conv(out)

            if self.with_res_shortcut:
                return x + out
            else:
                return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        return out
# ---- models/mmseg/models/utils/make_divisible.py (reconstructed from line-mangled patch) ----
def make_divisible(value, divisor, min_value=None, min_ratio=0.9):
    """Round ``value`` to the nearest multiple of ``divisor``.

    Ensures every layer gets a channel count divisible by ``divisor``,
    following the original TF slim MobileNet implementation:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py  # noqa

    Args:
        value (int): The original channel number.
        divisor (int): The divisor to fully divide the channel number.
        min_value (int): The minimum value of the output channel.
            Default: None, meaning the minimum equals the divisor.
        min_ratio (float): The minimum ratio of the rounded channel number
            to the original channel number. Default: 0.9.

    Returns:
        int: The modified output channel number.
    """
    floor = divisor if min_value is None else min_value
    # Round to the nearest multiple, but never below the floor.
    rounded = max(floor, int(value + divisor / 2) // divisor * divisor)
    # Ensure round-down never loses more than (1 - min_ratio) of `value`.
    if rounded < min_ratio * value:
        rounded += divisor
    return rounded


# ---- models/mmseg/models/utils/norm.py (reconstructed from line-mangled patch) ----
import math
import warnings

import torch


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Port of the PyTorch reference implementation; method based on
    # https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Standard normal cumulative distribution function.
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Sample uniformly in CDF space between the truncation bounds,
        # then map back through the inverse normal CDF (erfinv).
        lower = norm_cdf((a - mean) / std)
        upper = norm_cdf((b - mean) / std)
        tensor.uniform_(2 * lower - 1, 2 * upper - 1)
        tensor.erfinv_()

        # Scale and shift to the requested mean/std, then clamp to [a, b].
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)
        tensor.clamp_(min=a, max=b)
    return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    r"""Fill ``tensor`` in place with values drawn from a truncated normal
    distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`, redrawing
    values outside :math:`[a, b]` until they fall within the bounds. Works
    best when :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    Examples:
        >>> w = torch.empty(3, 5)
        >>> trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
class ResLayer(nn.Sequential):
    """Stack of residual blocks forming one ResNet-style stage.

    Args:
        block (nn.Module): Block class used to build the layer.
        inplanes (int): Input channels of the first block.
        planes (int): Base channels of every block.
        num_blocks (int): Number of blocks to stack.
        stride (int): Stride of the first block. Default: 1.
        dilation (int): Dilation of the blocks. Default: 1.
        avg_down (bool): Use AvgPool instead of stride conv when
            downsampling in the bottleneck. Default: False.
        conv_cfg (dict | None): Config dict for convolution layers.
            Default: None.
        norm_cfg (dict): Config dict for normalization layers.
            Default: dict(type='BN').
        multi_grid (tuple | None): Per-block dilation rates of the last
            stage. Default: None.
        contract_dilation (bool): Whether to contract (halve) the first
            block's dilation. Default: False.
    """

    def __init__(self,
                 block,
                 inplanes,
                 planes,
                 num_blocks,
                 stride=1,
                 dilation=1,
                 avg_down=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 multi_grid=None,
                 contract_dilation=False,
                 **kwargs):
        self.block = block
        out_channels = planes * block.expansion

        # A projection shortcut is required whenever the first block changes
        # either the spatial size or the channel count.
        downsample = None
        if stride != 1 or inplanes != out_channels:
            shortcut = []
            conv_stride = stride
            if avg_down:
                # Move the spatial reduction onto an average pool so the
                # 1x1 projection conv can keep stride 1.
                conv_stride = 1
                shortcut.append(
                    nn.AvgPool2d(
                        kernel_size=stride,
                        stride=stride,
                        ceil_mode=True,
                        count_include_pad=False))
            shortcut.append(
                build_conv_layer(
                    conv_cfg,
                    inplanes,
                    out_channels,
                    kernel_size=1,
                    stride=conv_stride,
                    bias=False))
            shortcut.append(build_norm_layer(norm_cfg, out_channels)[1])
            downsample = nn.Sequential(*shortcut)

        if multi_grid is not None:
            first_dilation = multi_grid[0]
        elif dilation > 1 and contract_dilation:
            first_dilation = dilation // 2
        else:
            first_dilation = dilation

        layers = [
            block(
                inplanes=inplanes,
                planes=planes,
                stride=stride,
                dilation=first_dilation,
                downsample=downsample,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                **kwargs)
        ]
        inplanes = out_channels
        for idx in range(1, num_blocks):
            layers.append(
                block(
                    inplanes=inplanes,
                    planes=planes,
                    stride=1,
                    dilation=dilation if multi_grid is None else multi_grid[idx],
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    **kwargs))
        super(ResLayer, self).__init__(*layers)
class SelfAttentionBlock(nn.Module):
    """General self-attention block/non-local block.

    Please refer to https://arxiv.org/abs/1706.03762 for details about key,
    query and value.

    Args:
        key_in_channels (int): Input channels of key feature.
        query_in_channels (int): Input channels of query feature.
        channels (int): Output channels of key/query transform.
        out_channels (int): Output channels.
        share_key_query (bool): Whether share projection weight between key
            and query projection.
        query_downsample (nn.Module | None): Query downsample module.
        key_downsample (nn.Module | None): Key downsample module.
        key_query_num_convs (int): Number of convs for key/query projection.
        value_out_num_convs (int): Number of convs for value/out projection.
        key_query_norm (bool): Whether to use ConvModule (conv + norm + act)
            instead of a plain ``nn.Conv2d`` for the key/query projection.
        value_out_norm (bool): Whether to use ConvModule instead of a plain
            ``nn.Conv2d`` for the value/out projection.
        matmul_norm (bool): Whether normalize attention map with sqrt of
            channels.
        with_out (bool): Whether use out projection.
        conv_cfg (dict | None): Config of conv layers.
        norm_cfg (dict | None): Config of norm layers.
        act_cfg (dict | None): Config of activation layers.
    """

    def __init__(self, key_in_channels, query_in_channels, channels,
                 out_channels, share_key_query, query_downsample,
                 key_downsample, key_query_num_convs, value_out_num_convs,
                 key_query_norm, value_out_norm, matmul_norm, with_out,
                 conv_cfg, norm_cfg, act_cfg):
        super(SelfAttentionBlock, self).__init__()
        if share_key_query:
            assert key_in_channels == query_in_channels
        self.key_in_channels = key_in_channels
        self.query_in_channels = query_in_channels
        self.out_channels = out_channels
        self.channels = channels
        self.share_key_query = share_key_query
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.key_project = self.build_project(
            key_in_channels,
            channels,
            num_convs=key_query_num_convs,
            use_conv_module=key_query_norm,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        if share_key_query:
            self.query_project = self.key_project
        else:
            self.query_project = self.build_project(
                query_in_channels,
                channels,
                num_convs=key_query_num_convs,
                use_conv_module=key_query_norm,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        # Without an out projection, the value projection must already emit
        # ``out_channels``.
        self.value_project = self.build_project(
            key_in_channels,
            channels if with_out else out_channels,
            num_convs=value_out_num_convs,
            use_conv_module=value_out_norm,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        if with_out:
            self.out_project = self.build_project(
                channels,
                out_channels,
                num_convs=value_out_num_convs,
                use_conv_module=value_out_norm,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        else:
            self.out_project = None

        self.query_downsample = query_downsample
        self.key_downsample = key_downsample
        self.matmul_norm = matmul_norm

        self.init_weights()

    def init_weights(self):
        """Initialize weight of later layer."""
        if self.out_project is not None:
            # Zero-init a plain conv out projection; ConvModule performs its
            # own initialization.
            if not isinstance(self.out_project, ConvModule):
                constant_init(self.out_project, 0)

    def build_project(self, in_channels, channels, num_convs, use_conv_module,
                      conv_cfg, norm_cfg, act_cfg):
        """Build projection layer for key/query/value/out."""
        if use_conv_module:
            convs = [
                ConvModule(
                    in_channels,
                    channels,
                    1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg)
            ]
            for _ in range(num_convs - 1):
                convs.append(
                    ConvModule(
                        channels,
                        channels,
                        1,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg))
        else:
            convs = [nn.Conv2d(in_channels, channels, 1)]
            for _ in range(num_convs - 1):
                convs.append(nn.Conv2d(channels, channels, 1))
        if len(convs) > 1:
            convs = nn.Sequential(*convs)
        else:
            convs = convs[0]
        return convs

    def forward(self, query_feats, key_feats):
        """Attend ``query_feats`` over ``key_feats``.

        Returns a tensor with the spatial size of ``query_feats`` and
        ``out_channels`` channels.
        """
        batch_size = query_feats.size(0)
        query = self.query_project(query_feats)
        if self.query_downsample is not None:
            query = self.query_downsample(query)
        # (B, C, H, W) -> (B, H*W, C)
        query = query.reshape(*query.shape[:2], -1)
        query = query.permute(0, 2, 1).contiguous()

        key = self.key_project(key_feats)
        value = self.value_project(key_feats)
        if self.key_downsample is not None:
            key = self.key_downsample(key)
            value = self.key_downsample(value)
        key = key.reshape(*key.shape[:2], -1)
        value = value.reshape(*value.shape[:2], -1)
        value = value.permute(0, 2, 1).contiguous()

        # (B, N_query, C) x (B, C, N_key) -> (B, N_query, N_key)
        sim_map = torch.matmul(query, key)
        if self.matmul_norm:
            sim_map = (self.channels**-.5) * sim_map
        sim_map = F.softmax(sim_map, dim=-1)

        context = torch.matmul(sim_map, value)
        context = context.permute(0, 2, 1).contiguous()
        context = context.reshape(batch_size, -1, *query_feats.shape[2:])
        if self.out_project is not None:
            context = self.out_project(context)
        return context
torch.nn as nn +from mmcv.cnn import ConvModule, build_upsample_layer + + +class UpConvBlock(nn.Module): + """Upsample convolution block in decoder for UNet. + + This upsample convolution block consists of one upsample module + followed by one convolution block. The upsample module expands the + high-level low-resolution feature map and the convolution block fuses + the upsampled high-level low-resolution feature map and the low-level + high-resolution feature map from encoder. + + Args: + conv_block (nn.Sequential): Sequential of convolutional layers. + in_channels (int): Number of input channels of the high-level + skip_channels (int): Number of input channels of the low-level + high-resolution feature map from encoder. + out_channels (int): Number of output channels. + num_convs (int): Number of convolutional layers in the conv_block. + Default: 2. + stride (int): Stride of convolutional layer in conv_block. Default: 1. + dilation (int): Dilation rate of convolutional layer in conv_block. + Default: 1. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Default: False. + conv_cfg (dict | None): Config dict for convolution layer. + Default: None. + norm_cfg (dict | None): Config dict for normalization layer. + Default: dict(type='BN'). + act_cfg (dict | None): Config dict for activation layer in ConvModule. + Default: dict(type='ReLU'). + upsample_cfg (dict): The upsample config of the upsample module in + decoder. Default: dict(type='InterpConv'). If the size of + high-level feature map is the same as that of skip feature map + (low-level feature map from encoder), it does not need upsample the + high-level feature map and the upsample_cfg is None. + dcn (bool): Use deformable convoluton in convolutional layer or not. + Default: None. + plugins (dict): plugins for convolutional layers. Default: None. 
class Encoding(nn.Module):
    """Encoding Layer: a learnable residual encoder.

    Input is of shape (batch_size, channels, height, width).
    Output is of shape (batch_size, num_codes, channels).

    Args:
        channels: dimension of the features or feature channels
        num_codes: number of code words
    """

    def __init__(self, channels, num_codes):
        super(Encoding, self).__init__()
        self.channels, self.num_codes = channels, num_codes
        # Codewords uniform in [-std, std]; per-code smoothing factors
        # uniform in [-1, 0].
        std = 1. / ((num_codes * channels)**0.5)
        # [num_codes, channels]
        self.codewords = nn.Parameter(
            torch.empty(num_codes, channels,
                        dtype=torch.float).uniform_(-std, std),
            requires_grad=True)
        # [num_codes]
        self.scale = nn.Parameter(
            torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0),
            requires_grad=True)

    @staticmethod
    def scaled_l2(x, codewords, scale):
        """Per-code scaled squared distance between features and codewords."""
        num_codes, channels = codewords.size()
        batch_size = x.size(0)
        # residuals: [batch_size, N, num_codes, channels]
        residuals = x.unsqueeze(2).expand(
            (batch_size, x.size(1), num_codes, channels)) - codewords.view(
                (1, 1, num_codes, channels))
        return scale.view((1, 1, num_codes)) * residuals.pow(2).sum(dim=3)

    @staticmethod
    def aggregate(assignment_weights, x, codewords):
        """Sum assignment-weighted residuals over all spatial positions."""
        num_codes, channels = codewords.size()
        batch_size = x.size(0)
        residuals = x.unsqueeze(2).expand(
            (batch_size, x.size(1), num_codes, channels)) - codewords.view(
                (1, 1, num_codes, channels))
        return (assignment_weights.unsqueeze(3) * residuals).sum(dim=1)

    def forward(self, x):
        assert x.dim() == 4 and x.size(1) == self.channels
        batch_size = x.size(0)
        # [batch_size, channels, height, width] -> [batch_size, H*W, channels]
        flat = x.view(batch_size, self.channels, -1).transpose(1, 2).contiguous()
        # assignment_weights: [batch_size, H*W, num_codes]
        assignment_weights = F.softmax(
            self.scaled_l2(flat, self.codewords, self.scale), dim=2)
        return self.aggregate(assignment_weights, flat, self.codewords)

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(Nx{self.channels}xHxW =>Nx{self.num_codes}' \
                    f'x{self.channels})'
        return repr_str
def resize(input,
           size=None,
           scale_factor=None,
           mode='nearest',
           align_corners=None,
           warning=True):
    """Wrapper around ``F.interpolate`` that warns about misaligned sizes.

    With ``align_corners=True``, interpolation is only exactly aligned when
    the output size is ``n * (input_size - 1) + 1``; warn when upsampling to
    a size that breaks this relation.

    Args:
        input (Tensor): NCHW input tensor.
        size (tuple | torch.Size | None): Target (H, W).
        scale_factor (float | tuple | None): Multiplier for spatial size.
        mode (str): Interpolation mode passed to ``F.interpolate``.
        align_corners (bool | None): See ``F.interpolate``.
        warning (bool): Whether to emit the alignment warning. Default: True.

    Returns:
        Tensor: The resized tensor.
    """
    if warning:
        if size is not None and align_corners:
            input_h, input_w = tuple(int(x) for x in input.shape[2:])
            output_h, output_w = tuple(int(x) for x in size)
            # Bug fix: the original compared output_w against output_h, so
            # downsampled widths could falsely trigger the warning. Compare
            # the output width against the *input* width instead.
            if output_h > input_h or output_w > input_w:
                if ((output_h > 1 and output_w > 1 and input_h > 1
                     and input_w > 1) and (output_h - 1) % (input_h - 1)
                        and (output_w - 1) % (input_w - 1)):
                    warnings.warn(
                        f'When align_corners={align_corners}, '
                        'the output would more aligned if '
                        f'input size {(input_h, input_w)} is `x+1` and '
                        f'out size {(output_h, output_w)} is `nx+1`')
    if isinstance(size, torch.Size):
        size = tuple(int(x) for x in size)
    return F.interpolate(input, size, scale_factor, mode, align_corners)


class Upsample(nn.Module):
    """``nn.Upsample`` drop-in that recomputes the target size at runtime.

    Args:
        size (tuple | None): Fixed output size; when falsy, the size is
            derived from ``scale_factor`` and the input's spatial shape.
        scale_factor (float | tuple | None): Spatial multiplier.
        mode (str): Interpolation mode. Default: 'nearest'.
        align_corners (bool | None): See ``F.interpolate``.
    """

    def __init__(self,
                 size=None,
                 scale_factor=None,
                 mode='nearest',
                 align_corners=None):
        super(Upsample, self).__init__()
        self.size = size
        if isinstance(scale_factor, tuple):
            self.scale_factor = tuple(float(factor) for factor in scale_factor)
        else:
            self.scale_factor = float(scale_factor) if scale_factor else None
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        if not self.size:
            # Compute the integer output size from the live input shape.
            size = [int(t * self.scale_factor) for t in x.shape[-2:]]
        else:
            size = self.size
        return resize(x, size, None, self.mode, self.align_corners)
Note that only the process of + rank 0 is affected, while other processes will set the level to + "Error" and be silent most of the time. + + Returns: + logging.Logger: The root logger. + """ + + logger = get_logger(name='mmseg', log_file=log_file, log_level=log_level) + + return logger + +def print_log(msg, logger=None, level=logging.INFO): + """Print a log message. + Args: + msg (str): The message to be logged. + logger (logging.Logger | str | None): The logger to be used. Some + special loggers are: + - "root": the root logger obtained with `get_root_logger()`. + - "silent": no message will be printed. + - None: The `print()` method will be used to print log messages. + level (int): Logging level. Only available when `logger` is a Logger + object or "root". + """ + if logger is None: + print(msg) + elif logger == 'root': + _logger = get_root_logger() + _logger.log(level, msg) + elif isinstance(logger, logging.Logger): + logger.log(level, msg) + elif logger != 'silent': + raise TypeError( + 'logger should be either a logging.Logger object, "root", ' + '"silent" or None, but got {}'.format(logger)) \ No newline at end of file diff --git a/models/mmseg/version.py b/models/mmseg/version.py new file mode 100644 index 0000000000000000000000000000000000000000..41a08cf155f9b938972b331c216349ef848840a5 --- /dev/null +++ b/models/mmseg/version.py @@ -0,0 +1,18 @@ +# Copyright (c) Open-MMLab. All rights reserved. 
models = {}


def register(name):
    """Return a decorator that records the decorated class in ``models``.

    Args:
        name (str): Registry key the class is stored under.
    """
    def decorator(cls):
        models[name] = cls
        return cls
    return decorator


def make(model_spec, args=None, load_sd=False):
    """Instantiate a registered model from a spec dict.

    Args:
        model_spec (dict): ``{'name': <registry key>, 'args': {...}}``;
            ``'args'`` may be omitted or None for argument-less models, and
            ``'sd'`` holds a state dict when ``load_sd`` is True.
        args (dict | None): Extra keyword arguments that override the
            spec's. Default: None.
        load_sd (bool): Load ``model_spec['sd']`` into the model.
            Default: False.

    Returns:
        The instantiated model.
    """
    # Always deep-copy so neither the override path nor plain construction
    # can mutate (or alias) the caller's spec dict.
    model_args = copy.deepcopy(model_spec.get('args') or {})
    if args is not None:
        model_args.update(args)
    model = models[model_spec['name']](**model_args)
    if load_sd:
        model.load_state_dict(model_spec['sd'])
    return model
+ """ + super().__init__() + self.model = sam_model + self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) + self.reset_image() + + def set_image( + self, + image: np.ndarray, + image_format: str = "RGB", + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. + + Arguments: + image (np.ndarray): The image for calculating masks. Expects an + image in HWC uint8 format, with pixel values in [0, 255]. + image_format (str): The color format of the image, in ['RGB', 'BGR']. + """ + assert image_format in [ + "RGB", + "BGR", + ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." + if image_format != self.model.image_format: + image = image[..., ::-1] + + # Transform the image to the form expected by the model + input_image = self.transform.apply_image(image) + input_image_torch = torch.as_tensor(input_image, device=self.device) + input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] + + self.set_torch_image(input_image_torch, image.shape[:2]) + + @torch.no_grad() + def set_torch_image( + self, + transformed_image: torch.Tensor, + original_image_size: Tuple[int, ...], + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. Expects the input + image to be already transformed to the format expected by the model. + + Arguments: + transformed_image (torch.Tensor): The input image, with shape + 1x3xHxW, which has been transformed with ResizeLongestSide. + original_image_size (tuple(int, int)): The size of the image + before transformation, in (H, W) format. + """ + assert ( + len(transformed_image.shape) == 4 + and transformed_image.shape[1] == 3 + and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size + ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." 
+ self.reset_image() + + self.original_size = original_image_size + self.input_size = tuple(transformed_image.shape[-2:]) + input_image = self.model.preprocess(transformed_image) + self.features = self.model.image_encoder(input_image) + self.is_image_set = True + + def predict( + self, + point_coords: Optional[np.ndarray] = None, + point_labels: Optional[np.ndarray] = None, + box: Optional[np.ndarray] = None, + mask_input: Optional[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Predict masks for the given input prompts, using the currently set image. + + Arguments: + point_coords (np.ndarray or None): A Nx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (np.ndarray or None): A length N array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + box (np.ndarray or None): A length 4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form 1xHxW, where + for SAM, H=W=256. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (np.ndarray): The output masks in CxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (np.ndarray): An array of length C containing the model's + predictions for the quality of each mask. 
+ (np.ndarray): An array of shape CxHxW, where C is the number + of masks and H=W=256. These low resolution logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") + + # Transform input prompts + coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None + if point_coords is not None: + assert ( + point_labels is not None + ), "point_labels must be supplied if point_coords is supplied." + point_coords = self.transform.apply_coords(point_coords, self.original_size) + coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) + labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) + coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] #变成 BN2 AND B N + if box is not None: + box = self.transform.apply_boxes(box, self.original_size) + box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) + box_torch = box_torch[None, :] + if mask_input is not None: + mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) + mask_input_torch = mask_input_torch[None, :, :, :] + + masks, iou_predictions, low_res_masks = self.predict_torch( + coords_torch, + labels_torch, + box_torch, + mask_input_torch, + multimask_output, + return_logits=return_logits, + ) + + masks_np = masks[0].detach().cpu().numpy() + iou_predictions_np = iou_predictions[0].detach().cpu().numpy() + low_res_masks_np = low_res_masks[0].detach().cpu().numpy() + return masks_np, iou_predictions_np, low_res_masks_np + + @torch.no_grad() + def predict_torch( + self, + point_coords: Optional[torch.Tensor], + point_labels: Optional[torch.Tensor], + boxes: Optional[torch.Tensor] = None, + mask_input: Optional[torch.Tensor] = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + 
""" + Predict masks for the given input prompts, using the currently set image. + Input prompts are batched torch tensors and are expected to already be + transformed to the input frame using ResizeLongestSide. + + Arguments: + point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (torch.Tensor or None): A BxN array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + boxes (np.ndarray or None): A Bx4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form Bx1xHxW, where + for SAM, H=W=256. Masks returned by a previous iteration of the + predict method do not need further transformation. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (torch.Tensor): The output masks in BxCxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (torch.Tensor): An array of shape BxC containing the model's + predictions for the quality of each mask. + (torch.Tensor): An array of shape BxCxHxW, where C is the number + of masks and H=W=256. These low res logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError("An image must be set with .set_image(...) 
before mask prediction.") + + if point_coords is not None: + points = (point_coords, point_labels) + else: + points = None + + # Embed prompts + sparse_embeddings, dense_embeddings = self.model.prompt_encoder( + points=points, + boxes=boxes, + masks=mask_input, + ) + + # Predict masks + low_res_masks, iou_predictions = self.model.mask_decoder( + image_embeddings=self.features, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + + # Upscale the masks to the original image resolution + masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) + + if not return_logits: + masks = masks > self.model.mask_threshold + + return masks, iou_predictions, low_res_masks + + def get_image_embedding(self) -> torch.Tensor: + """ + Returns the image embeddings for the currently set image, with + shape 1xCxHxW, where C is the embedding dimension and (H,W) are + the embedding spatial dimension of SAM (typically C=256, H=W=64). + """ + if not self.is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) to generate an embedding." + ) + assert self.features is not None, "Features must exist if an image has been set." 
+ return self.features + + @property + def device(self) -> torch.device: + return self.model.device + + def reset_image(self) -> None: + """Resets the currently set image.""" + self.is_image_set = False + self.features = None + self.orig_h = None + self.orig_w = None + self.input_h = None + self.input_w = None diff --git a/models/sam.py b/models/sam.py new file mode 100644 index 0000000000000000000000000000000000000000..92e00035e4caa2c1f490f07c45f1cc03142f15e5 --- /dev/null +++ b/models/sam.py @@ -0,0 +1,361 @@ +import logging +from functools import partial + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from models import register +from .mmseg.models.sam import ImageEncoderViT, MaskDecoder, TwoWayTransformer + +logger = logging.getLogger(__name__) +from .iou_loss import IOU +from typing import Any, Optional, Tuple + +from .mmseg.models.sam import PromptEncoder + +def init_weights(layer): + if type(layer) == nn.Conv2d: + nn.init.normal_(layer.weight, mean=0.0, std=0.02) + nn.init.constant_(layer.bias, 0.0) + elif type(layer) == nn.Linear: + nn.init.normal_(layer.weight, mean=0.0, std=0.02) + nn.init.constant_(layer.bias, 0.0) + elif type(layer) == nn.BatchNorm2d: + # print(layer) + nn.init.normal_(layer.weight, mean=1.0, std=0.02) + nn.init.constant_(layer.bias, 0.0) + +class BBCEWithLogitLoss(nn.Module): + ''' + Balanced BCEWithLogitLoss + ''' + def __init__(self): + super(BBCEWithLogitLoss, self).__init__() + + def forward(self, pred, gt): + eps = 1e-10 + count_pos = torch.sum(gt) + eps + count_neg = torch.sum(1. 
- gt) + ratio = count_neg / count_pos + w_neg = count_pos / (count_pos + count_neg) + + bce1 = nn.BCEWithLogitsLoss(pos_weight=ratio) + loss = w_neg * bce1(pred, gt) + + return loss + +def _iou_loss(pred, target): + print('*****&&&', pred.shape, target.shape) + pred = torch.sigmoid(pred) + inter = (pred * target).sum(dim=(2, 3)) + union = (pred + target).sum(dim=(2, 3)) - inter + iou = 1 - (inter / union) + + return iou.mean() + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer( + "positional_encoding_gaussian_matrix", + scale * torch.randn((2, num_pos_feats)), + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... 
x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: int) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size, size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + +@register('sam') +class SAM(nn.Module): + def __init__(self, inp_size=None, encoder_mode=None, loss=None): + super().__init__() + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.embed_dim = encoder_mode['embed_dim'] + self.image_encoder = ImageEncoderViT( + img_size=inp_size, + patch_size=encoder_mode['patch_size'], + in_chans=3, + embed_dim=encoder_mode['embed_dim'], + depth=encoder_mode['depth'], + num_heads=encoder_mode['num_heads'], + mlp_ratio=encoder_mode['mlp_ratio'], + out_chans=encoder_mode['out_chans'], + qkv_bias=encoder_mode['qkv_bias'], + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU, + use_rel_pos=encoder_mode['use_rel_pos'], + rel_pos_zero_init=True, + window_size=encoder_mode['window_size'], + global_attn_indexes=encoder_mode['global_attn_indexes'], + ) + self.prompt_embed_dim = encoder_mode['prompt_embed_dim']#256 + prompt_embed_dim = 256 + image_embedding_size = inp_size / 16 + self.prompt_encoder = PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(int(image_embedding_size), int(image_embedding_size)), + input_image_size=(inp_size, inp_size), + mask_in_chans=16, + ) + + + self.mask_decoder = MaskDecoder( + # num_multimask_outputs=3, + # num_multimask_outputs=15,#iasid + # num_multimask_outputs=5, + # num_multimask_outputs=25, + num_multimask_outputs=14, + # num_multimask_outputs=26, + 
transformer=TwoWayTransformer( + depth=2, + embedding_dim=self.prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=self.prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + ) + self.mask_decoder_diwu = MaskDecoder( + # num_multimask_outputs=3, + # num_multimask_outputs=15,#iasid + # num_multimask_outputs=5, + # num_multimask_outputs=25, + # num_multimask_outputs=12, + num_multimask_outputs=12, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=self.prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=self.prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + ) + + if 'evp' in encoder_mode['name']: + for k, p in self.encoder.named_parameters(): + if "prompt" not in k and "mask_decoder" not in k and "prompt_encoder" not in k: + p.requires_grad = False + + + + self.loss_mode = loss + if self.loss_mode == 'bce': + self.criterionBCE = torch.nn.BCEWithLogitsLoss() + + elif self.loss_mode == 'bbce': + self.criterionBCE = BBCEWithLogitLoss() + + elif self.loss_mode == 'iou': + self.criterionBCE = torch.nn.BCEWithLogitsLoss() + self.criterionIOU = IOU() + + elif self.loss_mode == 'cr': + # self.criterionCR = torch.nn.CrossEntropyLoss(ignore_index=255, reduction='mean') + self.criterionCR = torch.nn.CrossEntropyLoss(ignore_index=25, reduction='mean') + # 鑳屾櫙绫讳笉鍙備笌璁$畻loss + self.criterionIOU = IOU() + + self.pe_layer = PositionEmbeddingRandom(encoder_mode['prompt_embed_dim'] // 2) + self.inp_size = inp_size + self.image_embedding_size = inp_size // encoder_mode['patch_size']#1024/16 + self.no_mask_embed = nn.Embedding(1, encoder_mode['prompt_embed_dim'])#256 + + def set_input(self, input, gt_mask): + self.input = input.to(self.device) + self.gt_mask = gt_mask.to(self.device) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. 
+ + Returns: + torch.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + + def forward(self): + bs = 1 + + # Embed prompts + sparse_embeddings = torch.empty((bs, 0, self.prompt_embed_dim), device=self.input.device)#绌簍ensor + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size, self.image_embedding_size + ) + #鎻愬彇 image embedding + # print('-----input-----',self.input.shape) + self.features = self.image_encoder(self.input) #鏈€鍚庝竴灞傝緭鍑? # print('-----image emded-----', self.features.shape) + # Predict masks + low_res_masks, iou_predictions = self.mask_decoder( + image_embeddings=self.features, + image_pe=self.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + # multimask_output=False, + multimask_output=True, + )#B*C+1*H*W + low_res_masks_2, iou_predictions_2 = self.mask_decoder_diwu( + image_embeddings=self.features, + image_pe=self.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + # multimask_output=False, + multimask_output=True, + )#B*C+1*H*W + # print('----before cat',low_res_masks.shape, low_res_masks_2.shape) + low_res_masks = torch.cat((low_res_masks, low_res_masks_2), 1) + # print('----behind cat',low_res_masks.shape) + # Upscale the masks to the original image resolution + masks = self.postprocess_masks(low_res_masks, self.inp_size, self.inp_size) + self.pred_mask = masks + + def infer(self, input): + bs = 1 + + # Embed prompts + sparse_embeddings = torch.empty((bs, 0, self.prompt_embed_dim), device=input.device) + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size, self.image_embedding_size + ) + + self.features = self.image_encoder(input) + + # Predict masks + low_res_masks, iou_predictions = self.mask_decoder( + 
image_embeddings=self.features, + image_pe=self.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + # multimask_output=False, + multimask_output=True, + )#b*1*256*256 + low_res_masks_2, iou_predictions_2 = self.mask_decoder_diwu( + image_embeddings=self.features, + image_pe=self.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + # multimask_output=False, + multimask_output=True, + ) # B*C+1*H*W + # print('----before cat',low_res_masks.shape, low_res_masks_2.shape) + low_res_masks = torch.cat((low_res_masks, low_res_masks_2), 1) + + + # Upscale the masks to the original image resolution + #b*1*1024*1024 + masks = self.postprocess_masks(low_res_masks, self.inp_size, self.inp_size)#涓婇噰鏍疯嚦鍘熷浘澶у皬 + # masks = masks.sigmoid() + return masks + + def postprocess_masks( + self, + masks: torch.Tensor, + input_size: Tuple[int, ...], + original_size: Tuple[int, ...], + ) -> torch.Tensor: + """ + Remove padding and upscale masks to the original image size. + + Arguments: + masks (torch.Tensor): Batched masks from the mask_decoder, + in BxCxHxW format. + input_size (tuple(int, int)): The size of the image input to the + model, in (H, W) format. Used to remove padding. + original_size (tuple(int, int)): The original size of the image + before resizing for input to the model, in (H, W) format. + + Returns: + (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) + is given by original_size. 
+ """ + + masks = F.interpolate( + masks, + (self.image_encoder.img_size, self.image_encoder.img_size), + mode="bilinear", + align_corners=False, + ) + masks = masks[..., : input_size, : input_size] + masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) + return masks + + def backward_G(self): + """Calculate GAN and L1 loss for the generator""" + # self.loss_G = self.criterionBCE(self.pred_mask, self.gt_mask) + # if self.loss_mode == 'iou': + # self.loss_G += _iou_loss(self.pred_mask, self.gt_mask) + # print('^&&&*###',self.pred_mask.shape, self.gt_mask.shape) + # print(torch.unique(self.gt_mask)) + self.loss_G = self.criterionCR(self.pred_mask, self.gt_mask.squeeze(1).long()) + # if self.loss_mode == 'cr': + # self.loss_G += _iou_loss(self.pred_mask, self.gt_mask) + + # print('***selg gt masks',torch.unique(self.gt_mask)) + # print('####', self.loss_G) + self.loss_G.backward() + def _backward_(self, pred_mask, gt_mask): + self.loss_G = self.criterionCR(pred_mask, gt_mask.squeeze(1).long()) + self.loss_G.backward() + + def optimize_parameters(self): + self.forward() + self.optimizer.zero_grad() # set G's gradients to zero + self.backward_G() # calculate graidents for G + self.optimizer.step() # udpate G's weights + + def preprocess(self, x: torch.Tensor) -> torch.Tensor: + """Normalize pixel values and pad to a square input.""" + # Normalize colors + x = (x - self.pixel_mean) / self.pixel_std + + # Pad + h, w = x.shape[-2:] + padh = self.image_encoder.img_size - h + padw = self.image_encoder.img_size - w + x = F.pad(x, (0, padw, 0, padh)) + return x + + def set_requires_grad(self, nets, requires_grad=False): + """Set requies_grad=Fasle for all the networks to avoid unnecessary computations + Parameters: + nets (network list) -- a list of networks + requires_grad (bool) -- whether the networks require gradients or not + """ + if not isinstance(nets, list): + nets = [nets] + for net in nets: + if net is not None: + for param in 
net.parameters(): + param.requires_grad = requires_grad diff --git a/models/sam_single.py b/models/sam_single.py new file mode 100644 index 0000000000000000000000000000000000000000..0af88507cd093ebf94625a5660a269f2795af40c --- /dev/null +++ b/models/sam_single.py @@ -0,0 +1,364 @@ +import logging +from functools import partial + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from models import register +from .mmseg.models.sam import ImageEncoderViT, MaskDecoder, TwoWayTransformer + +logger = logging.getLogger(__name__) +from .iou_loss import IOU +from typing import Any, Optional, Tuple + +from .mmseg.models.sam import PromptEncoder + +def init_weights(layer): + if type(layer) == nn.Conv2d: + nn.init.normal_(layer.weight, mean=0.0, std=0.02) + nn.init.constant_(layer.bias, 0.0) + elif type(layer) == nn.Linear: + nn.init.normal_(layer.weight, mean=0.0, std=0.02) + nn.init.constant_(layer.bias, 0.0) + elif type(layer) == nn.BatchNorm2d: + # print(layer) + nn.init.normal_(layer.weight, mean=1.0, std=0.02) + nn.init.constant_(layer.bias, 0.0) + +class BBCEWithLogitLoss(nn.Module): + ''' + Balanced BCEWithLogitLoss + ''' + def __init__(self): + super(BBCEWithLogitLoss, self).__init__() + + def forward(self, pred, gt): + eps = 1e-10 + count_pos = torch.sum(gt) + eps + count_neg = torch.sum(1. - gt) + ratio = count_neg / count_pos + w_neg = count_pos / (count_pos + count_neg) + + bce1 = nn.BCEWithLogitsLoss(pos_weight=ratio) + loss = w_neg * bce1(pred, gt) + + return loss + +def _iou_loss(pred, target): + print('*****&&&', pred.shape, target.shape) + pred = torch.sigmoid(pred) + inter = (pred * target).sum(dim=(2, 3)) + union = (pred + target).sum(dim=(2, 3)) - inter + iou = 1 - (inter / union) + + return iou.mean() + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. 
+ """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer( + "positional_encoding_gaussian_matrix", + scale * torch.randn((2, num_pos_feats)), + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: int) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size, size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + +@register('sam_single') +class SAM(nn.Module): + def __init__(self, inp_size=None, encoder_mode=None, loss=None): + super().__init__() + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.embed_dim = encoder_mode['embed_dim'] + self.image_encoder = ImageEncoderViT( + img_size=inp_size, + patch_size=encoder_mode['patch_size'], + in_chans=3, + embed_dim=encoder_mode['embed_dim'], + depth=encoder_mode['depth'], + num_heads=encoder_mode['num_heads'], + mlp_ratio=encoder_mode['mlp_ratio'], + out_chans=encoder_mode['out_chans'], + qkv_bias=encoder_mode['qkv_bias'], + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU, + use_rel_pos=encoder_mode['use_rel_pos'], + rel_pos_zero_init=True, + window_size=encoder_mode['window_size'], + 
global_attn_indexes=encoder_mode['global_attn_indexes'], + ) + self.prompt_embed_dim = encoder_mode['prompt_embed_dim']#256 + prompt_embed_dim = 256 + image_embedding_size = inp_size / 16 + self.prompt_encoder = PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(int(image_embedding_size), int(image_embedding_size)), + input_image_size=(inp_size, inp_size), + mask_in_chans=16, + ) + + + self.mask_decoder = MaskDecoder( + # num_multimask_outputs=3, + # num_multimask_outputs=15,#iasid + # num_multimask_outputs=5, + # num_multimask_outputs=25, + # num_multimask_outputs=14, + # num_multimask_outputs=12, + # num_multimask_outputs=14, + num_multimask_outputs=26, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=self.prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=self.prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + ) + # self.mask_decoder_diwu = MaskDecoder( + # # num_multimask_outputs=3, + # # num_multimask_outputs=15,#iasid + # # num_multimask_outputs=5, + # # num_multimask_outputs=25, + # num_multimask_outputs=12, + # # num_multimask_outputs=26, + # transformer=TwoWayTransformer( + # depth=2, + # embedding_dim=self.prompt_embed_dim, + # mlp_dim=2048, + # num_heads=8, + # ), + # transformer_dim=self.prompt_embed_dim, + # iou_head_depth=3, + # iou_head_hidden_dim=256, + # ) + + if 'evp' in encoder_mode['name']: + for k, p in self.encoder.named_parameters(): + if "prompt" not in k and "mask_decoder" not in k and "prompt_encoder" not in k: + p.requires_grad = False + + + + self.loss_mode = loss + if self.loss_mode == 'bce': + self.criterionBCE = torch.nn.BCEWithLogitsLoss() + + elif self.loss_mode == 'bbce': + self.criterionBCE = BBCEWithLogitLoss() + + elif self.loss_mode == 'iou': + self.criterionBCE = torch.nn.BCEWithLogitsLoss() + self.criterionIOU = IOU() + + elif self.loss_mode == 'cr': + # self.criterionCR = torch.nn.CrossEntropyLoss(ignore_index=255, reduction='mean') + self.criterionCR = 
torch.nn.CrossEntropyLoss(ignore_index=25, reduction='mean') + # 背景类不参与计算loss + self.criterionIOU = IOU() + + self.pe_layer = PositionEmbeddingRandom(encoder_mode['prompt_embed_dim'] // 2) + self.inp_size = inp_size + self.image_embedding_size = inp_size // encoder_mode['patch_size']#1024/16 + self.no_mask_embed = nn.Embedding(1, encoder_mode['prompt_embed_dim'])#256 + + def set_input(self, input, gt_mask): + self.input = input.to(self.device) + self.gt_mask = gt_mask.to(self.device) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. + + Returns: + torch.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + + def forward(self): + bs = 1 + + # Embed prompts + sparse_embeddings = torch.empty((bs, 0, self.prompt_embed_dim), device=self.input.device)#空tensor + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size, self.image_embedding_size + ) + #提取 image embedding + # print('-----input-----',self.input.shape) + self.features = self.image_encoder(self.input) #最后一层输出 + # print('-----image emded-----', self.features.shape) + # Predict masks + low_res_masks, iou_predictions = self.mask_decoder( + image_embeddings=self.features, + image_pe=self.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + # multimask_output=False, + multimask_output=True, + )#B*C+1*H*W + # low_res_masks_2, iou_predictions_2 = self.mask_decoder_diwu( + # image_embeddings=self.features, + # image_pe=self.get_dense_pe(), + # sparse_prompt_embeddings=sparse_embeddings, + # dense_prompt_embeddings=dense_embeddings, + # # multimask_output=False, + # multimask_output=True, + # )#B*C+1*H*W + # print('----before cat',low_res_masks.shape, low_res_masks_2.shape) + # 
low_res_masks = torch.cat((low_res_masks, low_res_masks_2), 1) + # print('----beshind cat',low_res_masks.shape) + # Upscale the masks to the original image resolution + masks = self.postprocess_masks(low_res_masks, self.inp_size, self.inp_size) + self.pred_mask = masks + + def infer(self, input): + bs = 1 + + # Embed prompts + sparse_embeddings = torch.empty((bs, 0, self.prompt_embed_dim), device=input.device) + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size, self.image_embedding_size + ) + + self.features = self.image_encoder(input) + + # Predict masks + low_res_masks, iou_predictions = self.mask_decoder( + image_embeddings=self.features, + image_pe=self.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + # multimask_output=False, + multimask_output=True, + )#b*1*256*256 + # low_res_masks_2, iou_predictions_2 = self.mask_decoder_diwu( + # image_embeddings=self.features, + # image_pe=self.get_dense_pe(), + # sparse_prompt_embeddings=sparse_embeddings, + # dense_prompt_embeddings=dense_embeddings, + # # multimask_output=False, + # multimask_output=True, + # ) # B*C+1*H*W + # print('----before cat',low_res_masks.shape, low_res_masks_2.shape) + # low_res_masks = torch.cat((low_res_masks, low_res_masks_2), 1) + + + # Upscale the masks to the original image resolution + #b*1*1024*1024 + masks = self.postprocess_masks(low_res_masks, self.inp_size, self.inp_size)#上采样至原图大小 + # masks = masks.sigmoid() + return masks + + def postprocess_masks( + self, + masks: torch.Tensor, + input_size: Tuple[int, ...], + original_size: Tuple[int, ...], + ) -> torch.Tensor: + """ + Remove padding and upscale masks to the original image size. + + Arguments: + masks (torch.Tensor): Batched masks from the mask_decoder, + in BxCxHxW format. + input_size (tuple(int, int)): The size of the image input to the + model, in (H, W) format. Used to remove padding. 
+ original_size (tuple(int, int)): The original size of the image + before resizing for input to the model, in (H, W) format. + + Returns: + (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) + is given by original_size. + """ + + masks = F.interpolate( + masks, + (self.image_encoder.img_size, self.image_encoder.img_size), + mode="bilinear", + align_corners=False, + ) + masks = masks[..., : input_size, : input_size] + masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) + return masks + + def backward_G(self): + """Calculate GAN and L1 loss for the generator""" + # self.loss_G = self.criterionBCE(self.pred_mask, self.gt_mask) + # if self.loss_mode == 'iou': + # self.loss_G += _iou_loss(self.pred_mask, self.gt_mask) + # print('^&&&*###',self.pred_mask.shape, self.gt_mask.shape) + # print(torch.unique(self.gt_mask)) + self.loss_G = self.criterionCR(self.pred_mask, self.gt_mask.squeeze(1).long()) + # if self.loss_mode == 'cr': + # self.loss_G += _iou_loss(self.pred_mask, self.gt_mask) + + # print('***selg gt masks',torch.unique(self.gt_mask)) + # print('####', self.loss_G) + self.loss_G.backward() + def _backward_(self, pred_mask, gt_mask): + self.loss_G = self.criterionCR(pred_mask, gt_mask.squeeze(1).long()) + self.loss_G.backward() + + def optimize_parameters(self): + self.forward() + self.optimizer.zero_grad() # set G's gradients to zero + self.backward_G() # calculate graidents for G + self.optimizer.step() # udpate G's weights + + def preprocess(self, x: torch.Tensor) -> torch.Tensor: + """Normalize pixel values and pad to a square input.""" + # Normalize colors + x = (x - self.pixel_mean) / self.pixel_std + + # Pad + h, w = x.shape[-2:] + padh = self.image_encoder.img_size - h + padw = self.image_encoder.img_size - w + x = F.pad(x, (0, padw, 0, padh)) + return x + + def set_requires_grad(self, nets, requires_grad=False): + """Set requies_grad=Fasle for all the networks to avoid unnecessary computations + Parameters: + 
nets (network list) -- a list of networks + requires_grad (bool) -- whether the networks require gradients or not + """ + if not isinstance(nets, list): + nets = [nets] + for net in nets: + if net is not None: + for param in net.parameters(): + param.requires_grad = requires_grad diff --git a/models/util/__init__.py b/models/util/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/models/util/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/models/util/transforms.py b/models/util/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..c08ba1e3db751f3a5483a003be38c69c2cf2df85 --- /dev/null +++ b/models/util/transforms.py @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch.nn import functional as F +from torchvision.transforms.functional import resize, to_pil_image # type: ignore + +from copy import deepcopy +from typing import Tuple + + +class ResizeLongestSide: + """ + Resizes images to the longest side 'target_length', as well as provides + methods for resizing coordinates and boxes. Provides methods for + transforming both numpy array and batched torch tensors. + """ + + def __init__(self, target_length: int) -> None: + self.target_length = target_length + + def apply_image(self, image: np.ndarray) -> np.ndarray: + """ + Expects a numpy array with shape HxWxC in uint8 format. 
+ """ + target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) + return np.array(resize(to_pil_image(image), target_size)) + + def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + """ + Expects a numpy array of length 2 in the final dimension. Requires the + original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).astype(float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + """ + Expects a numpy array shape Bx4. Requires the original image size + in (H, W) format. + """ + boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: + """ + Expects batched images with shape BxCxHxW and float format. This + transformation may not exactly match apply_image. apply_image is + the transformation expected by the model. + """ + # Expects an image in BCHW format. May not exactly match apply_image. + target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) + return F.interpolate( + image, target_size, mode="bilinear", align_corners=False, antialias=True + ) + + def apply_coords_torch( + self, coords: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with length 2 in the last dimension. Requires the + original image size in (H, W) format. 
+ """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).to(torch.float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes_torch( + self, boxes: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with shape Bx4. Requires the original image + size in (H, W) format. + """ + boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + @staticmethod + def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: + """ + Compute the output size given input size and target long side length. + """ + scale = long_side_length * 1.0 / max(oldh, oldw) + newh, neww = oldh * scale, oldw * scale + neww = int(neww + 0.5) + newh = int(newh + 0.5) + return (newh, neww) diff --git a/models/utils_prompt.py b/models/utils_prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..992d23328e7188d87ee4827305f595d462258dfa --- /dev/null +++ b/models/utils_prompt.py @@ -0,0 +1,132 @@ +# coding:utf-8 +import os + +import numpy as np +import cv2 +from typing import Optional +import torch + +# from models.transforms import ResizeLongestSide +# from .transforms import ResizeLongestSide +from torchvision import transforms + +def get_prompt_inp_scatter(scatter_file_): + + scatter_mask = cv2.imread(scatter_file_, cv2.IMREAD_UNCHANGED) + + return scatter_mask + +def pre_scatter_prompt(scatter, filp, device): + if filp == True: + scatter = cv2.flip(scatter, 1) + + img_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + scatter_torch = img_transform(scatter) + scatter_torch = scatter_torch.to(device) + return scatter_torch + +def get_prompt_inp(txt_file_, filp): + + f = 
open(txt_file_) + lines = f.readlines() + points = [] + labels = [] + boxes = [] + masks = [] + for line in lines: + x_1, y_1, x_2, y_2, x_3, y_3, x_4, y_4, classname, _ = line.split(' ') + # print(x_1, y_1, x_2, y_2, x_3, y_3, x_4, y_4, classname, _) + x_1, y_1, x_2, y_2, x_3, y_3, x_4, y_4 = float(x_1), float(y_1), \ + float(x_2), float(y_2), \ + float(x_3), float(y_3), \ + float(x_4), float(y_4) + xmin = min(x_1, x_2, x_3, x_4) + xmax = max(x_1, x_2, x_3, x_4) + ymin = min(y_1, y_2, y_3, y_4) + ymax = max(y_1, y_2, y_3, y_4) + if filp: + xmin = 1024.0 - xmin + xmax = 1024.0 - xmax + x_center = (xmin + xmax)/2 + y_center = (ymin + ymax)/2 + point = [x_center, y_center] + box = [[xmin, ymin], [xmax, ymax]] + # box = [xmin, ymin, xmax, ymax] + mask = [] + points.append(point) + labels.append(classname) + boxes.append(box) + masks.append(mask) + # boxes = boxes[:1] + # return points, labels, boxes, masks + return points, labels, boxes, None + +def pre_prompt(points=None, boxes=None, masks=None, device=None): + + points_torch = points + if points != None: + # points = points/16.0 + points_torch = torch.as_tensor(points, dtype=torch.float, device=device) + points_torch = points_torch/16.0 + boxes_torch = boxes + if boxes != None: + # boxes = boxes/16.0 + boxes_torch = torch.as_tensor(boxes, dtype=torch.float, device=device) + boxes_torch = boxes_torch/16.0 + # for box in boxes: + # left_top, bottom_right = box + masks_torch = masks + if masks != None: + masks_torch = torch.as_tensor(masks, dtype=torch.float, device=device) + + return points_torch, boxes_torch, masks_torch + + + +# def pre_prompt( +# point_coords: Optional[np.ndarray] = None, +# point_labels: Optional[np.ndarray] = None, +# box: Optional[np.ndarray] = None, +# mask_input: Optional[np.ndarray] = None, +# device=None, +# original_size = [1024, 1024] +# ): +# +# transform = ResizeLongestSide(1024) +# # Transform input prompts +# coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, 
None +# if point_coords is not None: +# assert ( +# point_labels is not None +# ), "point_labels must be supplied if point_coords is supplied." +# point_coords = transform.apply_coords(point_coords, original_size) +# coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=device) +# labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=device) +# coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] +# if box is not None: +# box = transform.apply_boxes(box, original_size) +# box_torch = torch.as_tensor(box, dtype=torch.float, device=device) +# box_torch = box_torch[None, :] +# if mask_input is not None: +# mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=device) +# mask_input_torch = mask_input_torch[None, :, :, :] +# +# return coords_torch, labels_torch, box_torch, mask_input_torch + +if __name__ == '__main__': + txt_dir = './ISAID/train/trainprompt/sub_labelTxt/' + txt_list = os.listdir(txt_dir) + txt_file_0 = txt_dir + txt_list[0] + points, labels, boxes, masks = get_prompt_inp(txt_file_0) + print(points) + print(labels) + print(boxes) + # boxes = boxes / 16.0 + boxes_torch = torch.as_tensor(boxes, dtype=torch.float) + boxes_torch = boxes_torch/16.0 + print(boxes_torch, boxes_torch.shape) + diff --git a/until/__init__.py b/until/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fabcb72da041e6428e72653899a79bfb955ae636 --- /dev/null +++ b/until/__init__.py @@ -0,0 +1 @@ +from dota_utils import GetFileFromThisRootDir \ No newline at end of file diff --git a/until/dota_utils.py b/until/dota_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fd64b59288ae1bbdc1a5793ac679249634ee7024 --- /dev/null +++ b/until/dota_utils.py @@ -0,0 +1,321 @@ +import sys +import codecs +import numpy as np +import shapely.geometry as shgeo +import os +import re +import math +# import polyiou +""" + some basic functions which are useful for process DOTA 
data +""" + +wordname_15 = ['plane', 'baseball-diamond', 'bridge', 'ground-track-field', 'small-vehicle', 'large-vehicle', 'ship', 'tennis-court', + 'basketball-court', 'storage-tank', 'soccer-ball-field', 'roundabout', 'harbor', 'swimming-pool', 'helicopter'] + +def custombasename(fullname): + return os.path.basename(os.path.splitext(fullname)[0]) + +def GetFileFromThisRootDir(dir,ext = None): + allfiles = [] + needExtFilter = (ext != None) + for root,dirs,files in os.walk(dir): + for filespath in files: + filepath = os.path.join(root, filespath) + extension = os.path.splitext(filepath)[1][1:] + if needExtFilter and extension in ext: + allfiles.append(filepath) + elif not needExtFilter: + allfiles.append(filepath) + return allfiles + +def TuplePoly2Poly(poly): + outpoly = [poly[0][0], poly[0][1], + poly[1][0], poly[1][1], + poly[2][0], poly[2][1], + poly[3][0], poly[3][1] + ] + return outpoly + +def parse_dota_poly_refactor(filename, code): + """ + parse the dota ground truth in the format: + [(x1, y1), (x2, y2), (x3, y3), (x4, y4)] + """ + objects = [] + #print('filename:', filename) + f = [] + if (sys.version_info >= (3, 5)): + fd = open(filename, 'r') + f = fd + elif (sys.version_info >= 2.7): + fd = codecs.open(filename, 'r', code) + f = fd + # count = 0 + while True: + line = f.readline() + # count = count + 1 + # if count < 2: + # continue + if line: + splitlines = line.strip().split(' ') + object_struct = {} + ### clear the wrong name after check all the data + #if (len(splitlines) >= 9) and (splitlines[8] in classname): + if (len(splitlines) < 9): + continue + if (len(splitlines) >= 9): + object_struct['name'] = splitlines[8] + if (len(splitlines) == 9): + object_struct['difficult'] = '0' + elif (len(splitlines) >= 10): + # if splitlines[9] == '1': + # if (splitlines[9] == 'tr'): + # object_struct['difficult'] = '1' + # else: + object_struct['difficult'] = splitlines[9] + # else: + # object_struct['difficult'] = 0 + object_struct['poly'] = 
[(float(splitlines[0]), float(splitlines[1])), + (float(splitlines[2]), float(splitlines[3])), + (float(splitlines[4]), float(splitlines[5])), + (float(splitlines[6]), float(splitlines[7])) + ] + gtpoly = shgeo.Polygon(object_struct['poly']) + object_struct['area'] = gtpoly.area + # poly = list(map(lambda x:np.array(x), object_struct['poly'])) + # object_struct['long-axis'] = max(distance(poly[0], poly[1]), distance(poly[1], poly[2])) + # object_struct['short-axis'] = min(distance(poly[0], poly[1]), distance(poly[1], poly[2])) + # if (object_struct['long-axis'] < 15): + # object_struct['difficult'] = '1' + # global small_count + # small_count = small_count + 1 + objects.append(object_struct) + else: + break + return objects + +def parse_dota_poly(filename): + # 读进一个八个点坐标+类别+不知道啥的txt + # 最后生成一个列表,列表内是字典,每个字典表示一个物体,字典的键有:物体类别(name),难例(difficult),坐标(poly),面积(area) + # 这里的poly是dota8个点的坐标,注意更改,VOC最后要的是左上(小)右下坐标,coco要的是左上坐标和wh + # 注,我生成VOC可以不要area + """ + parse the dota ground truth in the format: + [(x1, y1), (x2, y2), (x3, y3), (x4, y4)] + """ + objects = [] + # print('filename:', filename) + f = [] + if (sys.version_info >= (3, 5)): + fd = open(filename, 'r') + f = fd + elif (sys.version_info >= 2.7): + fd = codecs.open(filename, 'r') + f = fd + # count = 0 + while True: + line = f.readline() + # count = count + 1 + # if count < 2: + # continue + if line: + splitlines = line.strip().split(' ') + object_struct = {} + ### clear the wrong name after check all the data + #if (len(splitlines) >= 9) and (splitlines[8] in classname): + if (len(splitlines) < 9): + continue + if (len(splitlines) >= 9): + object_struct['name'] = splitlines[8] + if (len(splitlines) == 9): + object_struct['difficult'] = '0' + elif (len(splitlines) >= 10): + # if splitlines[9] == '1': + # if (splitlines[9] == 'tr'): + # object_struct['difficult'] = '1' + # else: + object_struct['difficult'] = splitlines[9] + # else: + # object_struct['difficult'] = 0 + object_struct['poly'] = 
[(float(splitlines[0]), float(splitlines[1])), + (float(splitlines[2]), float(splitlines[3])), + (float(splitlines[4]), float(splitlines[5])), + (float(splitlines[6]), float(splitlines[7])) + ] + gtpoly = shgeo.Polygon(object_struct['poly']) + object_struct['area'] = gtpoly.area + # poly = list(map(lambda x:np.array(x), object_struct['poly'])) + # object_struct['long-axis'] = max(distance(poly[0], poly[1]), distance(poly[1], poly[2])) + # object_struct['short-axis'] = min(distance(poly[0], poly[1]), distance(poly[1], poly[2])) + # if (object_struct['long-axis'] < 15): + # object_struct['difficult'] = '1' + # global small_count + # small_count = small_count + 1 + objects.append(object_struct) + else: + break + return objects + +def parse_dota_poly2(filename): + """ + parse the dota ground truth in the format: + [x1, y1, x2, y2, x3, y3, x4, y4] + """ + objects = parse_dota_poly(filename) + for obj in objects: + obj['poly'] = TuplePoly2Poly(obj['poly']) # 把tuple转成列表形式[x1, y1, x2, y2, x3, y3, x4, y4],所以为啥最开始要搞成tuple呢 + obj['poly'] = list(map(int, obj['poly'])) # dota把所有点都搞成整数,我应该不需要! 
+ return objects + +def parse_dota_rec(filename): + """ + parse the dota ground truth in the bounding box format: + "xmin, ymin, xmax, ymax" + """ + objects = parse_dota_poly(filename) + for obj in objects: + poly = obj['poly'] + bbox = dots4ToRec4(poly) + obj['bndbox'] = bbox + return objects +## bounding box transfer for varies format + +def dots4ToRec4(poly): + xmin, xmax, ymin, ymax = min(poly[0][0], min(poly[1][0], min(poly[2][0], poly[3][0]))), \ + max(poly[0][0], max(poly[1][0], max(poly[2][0], poly[3][0]))), \ + min(poly[0][1], min(poly[1][1], min(poly[2][1], poly[3][1]))), \ + max(poly[0][1], max(poly[1][1], max(poly[2][1], poly[3][1]))) + return xmin, ymin, xmax, ymax +def dots4ToRec8(poly): + xmin, ymin, xmax, ymax = dots4ToRec4(poly) + return xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax + #return dots2ToRec8(dots4ToRec4(poly)) +def dots2ToRec8(rec): + xmin, ymin, xmax, ymax = rec[0], rec[1], rec[2], rec[3] + return xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax + +def groundtruth2Task1(srcpath, dstpath): + filelist = GetFileFromThisRootDir(srcpath) + # names = [custombasename(x.strip())for x in filelist] + filedict = {} + for cls in wordname_15: + fd = open(os.path.join(dstpath, 'Task1_') + cls + r'.txt', 'w') + filedict[cls] = fd + for filepath in filelist: + objects = parse_dota_poly2(filepath) + + subname = custombasename(filepath) + pattern2 = re.compile(r'__([\d+\.]+)__\d+___') + rate = re.findall(pattern2, subname)[0] + + for obj in objects: + category = obj['name'] + difficult = obj['difficult'] + poly = obj['poly'] + if difficult == '2': + continue + if rate == '0.5': + outline = custombasename(filepath) + ' ' + '1' + ' ' + ' '.join(map(str, poly)) + elif rate == '1': + outline = custombasename(filepath) + ' ' + '0.8' + ' ' + ' '.join(map(str, poly)) + elif rate == '2': + outline = custombasename(filepath) + ' ' + '0.6' + ' ' + ' '.join(map(str, poly)) + + filedict[category].write(outline + '\n') + +def Task2groundtruth_poly(srcpath, dstpath): 
+ thresh = 0.1 + filedict = {} + Tasklist = GetFileFromThisRootDir(srcpath, '.txt') + + for Taskfile in Tasklist: + idname = custombasename(Taskfile).split('_')[-1] + # idname = datamap_inverse[idname] + f = open(Taskfile, 'r') + lines = f.readlines() + for line in lines: + if len(line) == 0: + continue + # print('line:', line) + splitline = line.strip().split(' ') + filename = splitline[0] + confidence = splitline[1] + bbox = splitline[2:] + if float(confidence) > thresh: + if filename not in filedict: + # filedict[filename] = codecs.open(os.path.join(dstpath, filename + '.txt'), 'w', 'utf_16') + filedict[filename] = codecs.open(os.path.join(dstpath, filename + '.txt'), 'w') + # poly = util.dots2ToRec8(bbox) + poly = bbox + # filedict[filename].write(' '.join(poly) + ' ' + idname + '_' + str(round(float(confidence), 2)) + '\n') + # print('idname:', idname) + + # filedict[filename].write(' '.join(poly) + ' ' + idname + '_' + str(round(float(confidence), 2)) + '\n') + + filedict[filename].write(' '.join(poly) + ' ' + idname + '\n') + + +def polygonToRotRectangle(bbox): + """ + :param bbox: The polygon stored in format [x1, y1, x2, y2, x3, y3, x4, y4] + :return: Rotated Rectangle in format [cx, cy, w, h, theta] + """ + bbox = np.array(bbox,dtype=np.float32) + bbox = np.reshape(bbox,newshape=(2,4),order='F') + angle = math.atan2(-(bbox[0,1]-bbox[0,0]),bbox[1,1]-bbox[1,0]) + + center = [[0],[0]] + + for i in range(4): + center[0] += bbox[0,i] + center[1] += bbox[1,i] + + center = np.array(center,dtype=np.float32)/4.0 + + R = np.array([[math.cos(angle), -math.sin(angle)], [math.sin(angle), math.cos(angle)]], dtype=np.float32) + + normalized = np.matmul(R.transpose(),bbox-center) + + xmin = np.min(normalized[0,:]) + xmax = np.max(normalized[0,:]) + ymin = np.min(normalized[1,:]) + ymax = np.max(normalized[1,:]) + + w = xmax - xmin + 1 + h = ymax - ymin + 1 + + return [float(center[0]),float(center[1]),w,h,angle] + +def cal_line_length(point1, point2): + return math.sqrt( 
math.pow(point1[0] - point2[0], 2) + math.pow(point1[1] - point2[1], 2)) + +def get_best_begin_point(coordinate): + x1 = coordinate[0][0] + y1 = coordinate[0][1] + x2 = coordinate[1][0] + y2 = coordinate[1][1] + x3 = coordinate[2][0] + y3 = coordinate[2][1] + x4 = coordinate[3][0] + y4 = coordinate[3][1] + xmin = min(x1, x2, x3, x4) + ymin = min(y1, y2, y3, y4) + xmax = max(x1, x2, x3, x4) + ymax = max(y1, y2, y3, y4) + combinate = [[[x1, y1], [x2, y2], [x3, y3], [x4, y4]], [[x2, y2], [x3, y3], [x4, y4], [x1, y1]], + [[x3, y3], [x4, y4], [x1, y1], [x2, y2]], [[x4, y4], [x1, y1], [x2, y2], [x3, y3]]] + dst_coordinate = [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]] + force = 100000000.0 + force_flag = 0 + for i in range(4): + temp_force = cal_line_length(combinate[i][0], dst_coordinate[0]) + cal_line_length(combinate[i][1], + dst_coordinate[ + 1]) + cal_line_length( + combinate[i][2], dst_coordinate[2]) + cal_line_length(combinate[i][3], dst_coordinate[3]) + if temp_force < force: + force = temp_force + force_flag = i + if force_flag != 0: + print("choose one direction!") + return combinate[force_flag]