| |
| |
| |
| |
| |
| |
|
|
| import os |
| import json |
| import shutil |
|
|
| from torchvision import datasets, transforms |
|
|
| from timm.data import create_transform |
| from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD |
|
|
| import numpy as np |
| from PIL import Image |
| import random |
| import torch |
| from torch.utils.data import DataLoader, Dataset, ConcatDataset |
| from torchvision import transforms |
| from torch.nn import functional as F |
|
|
|
|
class collate_fn_crfrp:
    """Batch collate function implementing CRFR-P masking for masked pretraining.

    For every batch it (1) masks *all* patches of one randomly chosen facial
    region, then (2) spreads the remaining masking budget proportionally over
    the other regions until the overall ``mask_ratio`` is met.

    Expects each sample to be a dict with keys ``'image'`` (tensor) and
    ``'parsing_map'`` (integer label map, one label id per pixel).
    Returns a dict with ``'image'``, ``'img_mask'`` (final per-patch mask,
    flattened to ``num_patches``) and ``'specific_facial_region_mask'``
    (patches belonging to the chosen region).
    """

    def __init__(self, input_size=224, patch_size=16, mask_ratio=0.75):
        # Side length of the square input image, in pixels.
        self.img_size = input_size
        # Side length of one ViT-style patch, in pixels.
        self.patch_size = patch_size
        # Patch-grid geometry: patches per axis and total patch count.
        self.num_patches_axis = input_size // patch_size
        self.num_patches = (input_size // patch_size) ** 2
        # Target fraction of patches to mask per image.
        self.mask_ratio = mask_ratio

        # Groups of parsing-map label ids that are treated as one facial
        # region. NOTE(review): the label semantics are not visible in this
        # file — presumably a face-parsing scheme (e.g. eyebrows, eyes, nose,
        # mouth/lips, hair/skin/background); confirm against the tool that
        # produced the parsing maps.
        self.facial_region_group = [
            [2, 3],
            [4, 5],
            [6],
            [7, 8, 9],
            [10, 1, 0],
            [10],
            [1],
            [0]
        ]

    def __call__(self, samples):
        """Collate ``samples`` into a batch and build the CRFR-P masks."""
        image, img_mask, facial_region_mask, random_specific_facial_region \
            = self.CRFR_P_masking(samples, specified_facial_region=None)

        # random_specific_facial_region is computed but intentionally not
        # returned to the caller.
        return {'image': image, 'img_mask': img_mask, 'specific_facial_region_mask': facial_region_mask}

    def CRFR_P_masking(self, samples, specified_facial_region=None):
        """Run the two-stage masking pipeline on a list of samples.

        ``specified_facial_region`` is currently unused — the region is always
        chosen at random inside
        ``masking_all_patches_in_random_specific_facial_region``.
        """
        # Stack per-sample tensors into batch tensors; parsing maps come in
        # with a singleton channel dim that is squeezed away.
        image = torch.stack([sample['image'] for sample in samples])
        parsing_map = torch.stack([sample['parsing_map'] for sample in samples])
        parsing_map = parsing_map.squeeze(1)

        # Stage 1: fully mask one randomly chosen facial region.
        facial_region_mask = torch.zeros(parsing_map.size(0), self.num_patches_axis, self.num_patches_axis,
                                         dtype=torch.float32)
        facial_region_mask, random_specific_facial_region \
            = self.masking_all_patches_in_random_specific_facial_region(parsing_map, facial_region_mask)

        # Stage 2: top up (or trim) masking elsewhere to hit mask_ratio.
        img_mask, facial_region_mask \
            = self.variable_proportional_masking(parsing_map, facial_region_mask, random_specific_facial_region)

        del parsing_map
        return image, img_mask, facial_region_mask, random_specific_facial_region

    def masking_all_patches_in_random_specific_facial_region(self, parsing_map, facial_region_mask,
                                                             ):
        """Mask every patch overlapping one randomly chosen facial region.

        Returns the mask flattened to ``(batch, num_patches)`` plus the chosen
        label group. The last two singleton groups ([1], [0]) are excluded
        from the random choice by the ``[:-2]`` slice.
        """
        random_specific_facial_region = random.choice(self.facial_region_group[:-2])
        if random_specific_facial_region == [10, 1, 0]:
            # Special case: select patches that contain BOTH (label 10 or
            # label 0) pixels AND label-1 pixels — i.e. patches straddling the
            # boundary between those areas. max_pool2d over the patch grid
            # marks a patch if any pixel inside it matches.
            patch_hair_bg = F.max_pool2d(((parsing_map == 10) + (parsing_map == 0)).float(),
                                         kernel_size=self.patch_size)

            patch_skin = F.max_pool2d((parsing_map == 1).float(), kernel_size=self.patch_size)

            facial_region_mask = (patch_hair_bg.bool() & patch_skin.bool()).float()
        else:
            # Union of patches touched by any label in the chosen group.
            for facial_region_index in random_specific_facial_region:
                facial_region_mask = torch.maximum(facial_region_mask,
                                                   F.max_pool2d((parsing_map == facial_region_index).float(),
                                                                kernel_size=self.patch_size))

        return facial_region_mask.view(parsing_map.size(0), -1), random_specific_facial_region

    def variable_proportional_masking(self, parsing_map, facial_region_mask, random_specific_facial_region):
        """Adjust per-image masking to reach exactly ``mask_ratio`` patches.

        If the region mask falls short of the budget, each *other* region is
        masked at the same proportional rate, then a random fix-up handles the
        rounding residue. If the region mask already exceeds the budget,
        random patches inside it are unmasked instead.
        """
        img_mask = facial_region_mask.clone()

        other_facial_region_group = [region for region in self.facial_region_group if
                                     region != random_specific_facial_region]

        # Processed per image: each sample needs a different number of extra
        # masked patches.
        for i in range(facial_region_mask.size(0)):
            # How many patches are still needed to hit the budget (may be
            # negative if the region alone over-shot it).
            num_mask_to_change = (self.mask_ratio * self.num_patches - facial_region_mask[i].sum(dim=-1)).int()
            # 1 -> more patches must be masked; 0 -> patches must be unmasked.
            mask_change_to = torch.clamp(num_mask_to_change, 0, 1).item()

            if mask_change_to == 1:
                # Fraction of the remaining (unmasked) patches to mask, applied
                # uniformly to every other region.
                mask_ratio_other_fr = (
                        num_mask_to_change / (self.num_patches - facial_region_mask[i].sum(dim=-1)))

                # Running record of patches already claimed by earlier regions,
                # so each patch is only eligible once.
                masked_patches = facial_region_mask[i].clone()
                for other_fr in other_facial_region_group:
                    to_mask_patches = torch.zeros(1, self.num_patches_axis, self.num_patches_axis,
                                                  dtype=torch.float32)
                    if other_fr == [10, 1, 0]:
                        # Same boundary-patch construction as in stage 1, for
                        # this single image.
                        patch_hair_bg = F.max_pool2d(
                            ((parsing_map[i].unsqueeze(0) == 10) + (parsing_map[i].unsqueeze(0) == 0)).float(),
                            kernel_size=self.patch_size)
                        patch_skin = F.max_pool2d((parsing_map[i].unsqueeze(0) == 1).float(),
                                                  kernel_size=self.patch_size)

                        to_mask_patches = (patch_hair_bg.bool() & patch_skin.bool()).float()
                    else:
                        for facial_region_index in other_fr:
                            to_mask_patches = torch.maximum(to_mask_patches,
                                                            F.max_pool2d((parsing_map[i].unsqueeze(
                                                                0) == facial_region_index).float(),
                                                                kernel_size=self.patch_size))

                    # Drop patches already masked, then mask a proportional
                    # random subset of what is left for this region.
                    to_mask_patches = (to_mask_patches.view(-1) - masked_patches) > 0
                    select_indices = to_mask_patches.nonzero(as_tuple=False).view(-1)
                    change_indices = torch.randperm(len(select_indices))[
                                     :torch.round(to_mask_patches.sum() * mask_ratio_other_fr).int()]
                    img_mask[i, select_indices[change_indices]] = mask_change_to

                    masked_patches = masked_patches + to_mask_patches.float()

                # Rounding fix-up: randomly mask (or unmask) patches until the
                # total equals the budget exactly.
                num_mask_to_change = (self.mask_ratio * self.num_patches - img_mask[i].sum(dim=-1)).int()
                mask_change_to = torch.clamp(num_mask_to_change, 0, 1).item()
                # With mask_change_to == 1 this selects patches masked nowhere;
                # with 0 it selects patches masked in img_mask but outside the
                # chosen region (facial_region_mask is a subset of img_mask).
                select_indices = ((img_mask[i] + facial_region_mask[i]) == (1 - mask_change_to)).nonzero(
                    as_tuple=False).view(-1)
                change_indices = torch.randperm(len(select_indices))[:torch.abs(num_mask_to_change)]
                img_mask[i, select_indices[change_indices]] = mask_change_to

            else:
                # Region alone already exceeds the budget: unmask random
                # patches inside it, then resync facial_region_mask so both
                # outputs agree for this image.
                select_indices = (facial_region_mask[i] == (1 - mask_change_to)).nonzero(as_tuple=False).view(-1)
                change_indices = torch.randperm(len(select_indices))[:torch.abs(num_mask_to_change)]
                img_mask[i, select_indices[change_indices]] = mask_change_to
                facial_region_mask[i] = img_mask[i]

        return img_mask, facial_region_mask
|
|
|
|
class FaceParsingDataset(Dataset):
    """Dataset pairing face images with their pixel-level parsing maps.

    Expects ``<root>/images/*.png`` and, for each image, a same-named
    ``<root>/parsing_maps/*.npy`` label map. Each item is a dict with keys
    ``'image'`` (PIL image or transformed tensor) and ``'parsing_map'``
    (integer label tensor).
    """

    def __init__(self, root, transform=None):
        self.root_dir = root
        self.transform = transform
        self.image_folder = os.path.join(root, 'images')
        self.parsing_map_folder = os.path.join(root, 'parsing_maps')
        # Directory listing fixes the index -> filename mapping for this run.
        self.image_names = os.listdir(self.image_folder)

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        name = self.image_names[idx]
        img_path = os.path.join(self.image_folder, name)
        map_path = os.path.join(self.parsing_map_folder, name.replace('.png', '.npy'))

        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)

        # Parsing map is stored as a NumPy array on disk.
        parsing_map = torch.from_numpy(np.load(map_path))
        return {'image': image, 'parsing_map': parsing_map}
|
|
|
|
class TestImageFolder(datasets.ImageFolder):
    """ImageFolder that also returns the source video name for each frame.

    Frame files are expected to be named ``<video_name>_frame_<...>``; each
    item becomes ``(sample, target, video_name)`` so predictions can be
    aggregated per video at test time.
    """

    def __init__(self, root, transform=None, target_transform=None):
        super(TestImageFolder, self).__init__(root, transform, target_transform)

    def __getitem__(self, index):
        # (sample, target) from the standard ImageFolder implementation.
        original_tuple = super(TestImageFolder, self).__getitem__(index)

        # BUGFIX: was path.split('/')[-1], which breaks on Windows path
        # separators; os.path.basename is portable and identical on POSIX.
        filename = os.path.basename(self.imgs[index][0])
        video_name = filename.split('_frame_')[0]

        return original_tuple + (video_name,)
|
|
|
|
def get_mean_std(args):
    """Compute per-channel mean and std over the pretraining image set.

    Builds one FaceParsingDataset per path in ``args.data_path`` (concatenated
    when there are several), streams the images once through a DataLoader, and
    averages per-batch channel statistics. Returns ``(mean, std)`` as NumPy
    arrays of shape (3,).
    """
    print('dataset_paths:', args.data_path)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((args.input_size, args.input_size),
                          interpolation=transforms.InterpolationMode.BICUBIC),
    ])

    if len(args.data_path) > 1:
        parts = [FaceParsingDataset(root=p, transform=transform) for p in args.data_path]
        dataset_pretrain = ConcatDataset(parts)
    else:
        dataset_pretrain = FaceParsingDataset(root=args.data_path[0], transform=transform)

    print('Compute mean and variance for pretraining data.')
    print('len(dataset_train): ', len(dataset_pretrain))

    loader = DataLoader(
        dataset_pretrain,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    # Accumulate E[x] and E[x^2] per channel, batch by batch; the exact
    # population std would need per-pixel counts, this is the usual
    # batch-mean approximation.
    sum_c, sum_sq_c, n_batches = 0, 0, 0
    for batch in loader:
        imgs = batch['image']
        sum_c = sum_c + imgs.mean(dim=[0, 2, 3])
        sum_sq_c = sum_sq_c + (imgs ** 2).mean(dim=[0, 2, 3])
        n_batches += 1

    mean = sum_c / n_batches
    std = (sum_sq_c / n_batches - mean ** 2) ** 0.5

    print(f'train dataset mean%: {mean.numpy()} std: %{std.numpy()} ')
    return mean.numpy(), std.numpy()
|
|
|
|
def build_dataset(is_train, args):
    """Build an ImageFolder over ``args.data_path`` with the train/eval transform."""
    return datasets.ImageFolder(args.data_path,
                                transform=build_transform(is_train, args))
|
|
|
|
def build_transform(is_train, args):
    """Build the train or eval image transform.

    Normalization stats come either from the ImageNet defaults
    (``args.normalize_from_IMN``) or from a ``pretrain_ds_mean_std.txt`` JSON
    file produced during pretraining (copied next to ``args.finetune`` into
    ``args.output_dir`` on first use, read from ``args.resume``'s directory
    when evaluating).
    """
    stats_name = 'pretrain_ds_mean_std.txt'
    if args.normalize_from_IMN:
        mean = IMAGENET_DEFAULT_MEAN
        std = IMAGENET_DEFAULT_STD
    else:
        output_stats = os.path.join(args.output_dir, stats_name)
        # BUGFIX: this used os.path.join(args.output_dir, "/pretrain_...");
        # a second component starting with '/' makes os.path.join discard
        # output_dir entirely, so the exists() check always looked at the
        # filesystem root and the file was re-copied on every run.
        if not os.path.exists(output_stats) and not args.eval:
            shutil.copyfile(os.path.join(os.path.dirname(args.finetune), stats_name),
                            output_stats)
        stats_path = (os.path.join(os.path.dirname(args.resume), stats_name)
                      if args.eval else output_stats)
        # File's first line is a JSON object with 'mean' and 'std' lists.
        with open(stats_path, 'r') as file:
            ds_stat = json.loads(file.readline())
        mean = ds_stat['mean']
        std = ds_stat['std']

    if args.apply_simple_augment:
        if is_train:
            # Training: timm's standard augmentation pipeline (color jitter,
            # auto-augment, random erasing).
            return create_transform(
                input_size=args.input_size,
                is_training=True,
                color_jitter=args.color_jitter,
                auto_augment=args.aa,
                interpolation=transforms.InterpolationMode.BICUBIC,
                re_prob=args.reprob,
                re_mode=args.remode,
                re_count=args.recount,
                mean=mean,
                std=std,
            )

        # Eval: resize-then-center-crop with the conventional 224/256 ratio.
        crop_pct = 224 / 256 if args.input_size <= 224 else 1.0
        size = int(args.input_size / crop_pct)
        t = [
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(args.input_size),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
        return transforms.Compose(t)
    else:
        # BUGFIX: this branch read a bare `input_size` (NameError at runtime);
        # the value must come from args.
        crop_pct = args.input_size / 224 if args.input_size < 224 else 1.0
        size = int(args.input_size / crop_pct)
        # NOTE(review): unlike the branch above there is no CenterCrop here,
        # so for input_size < 224 the output stays at `size` (== 224), not
        # input_size — confirm this asymmetry is intended.
        t = [
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
        return transforms.Compose(t)
|
|