| """ |
| Copyright (c) Microsoft Corporation. |
| Licensed under the MIT license. |
| |
| """ |
|
|
|
|
| import os.path as op |
| import torch |
| import logging |
| import code |
| from custom_mesh_graphormer.utils.comm import get_world_size |
| from custom_mesh_graphormer.datasets.human_mesh_tsv import (MeshTSVDataset, MeshTSVYamlDataset) |
| from custom_mesh_graphormer.datasets.hand_mesh_tsv import (HandMeshTSVDataset, HandMeshTSVYamlDataset) |
|
|
|
|
def build_dataset(yaml_file, args, is_train=True, scale_factor=1):
    """Create a ``MeshTSVYamlDataset`` from a dataset YAML config.

    Args:
        yaml_file: path to the dataset YAML. If it does not name an existing
            file, it is looked up relative to ``args.data_dir``.
        args: namespace providing ``data_dir``.
        is_train: forwarded to the dataset constructor.
        scale_factor: forwarded to the dataset constructor.

    Returns:
        A ``MeshTSVYamlDataset`` built as
        ``MeshTSVYamlDataset(path, is_train, False, scale_factor)``.

    Raises:
        FileNotFoundError: if the YAML exists neither as given nor under
            ``args.data_dir``.
    """
    print(yaml_file)
    resolved = yaml_file
    if not op.isfile(resolved):
        resolved = op.join(args.data_dir, yaml_file)

    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let a missing config slip through silently.
    if not op.isfile(resolved):
        raise FileNotFoundError(
            "Dataset YAML not found: {!r} (also tried {!r})".format(yaml_file, resolved)
        )
    return MeshTSVYamlDataset(resolved, is_train, False, scale_factor)
|
|
|
|
class IterationBasedBatchSampler(torch.utils.data.sampler.BatchSampler):
    """
    Wraps a BatchSampler, resampling from it until
    a specified number of iterations have been sampled.

    The wrapped ``batch_sampler`` is iterated repeatedly, one full pass at a
    time. Batches are yielded while a counter that starts at ``start_iter``
    is below ``num_iterations``, so ``max(0, num_iterations - start_iter)``
    batches are produced in total. When resuming (``start_iter > 0``),
    sampling restarts at the beginning of a fresh pass.

    Args:
        batch_sampler: the BatchSampler to draw batches from.
        num_iterations: iteration count at which sampling stops.
        start_iter: iteration count to resume from.
    """

    def __init__(self, batch_sampler, num_iterations, start_iter=0):
        self.batch_sampler = batch_sampler
        self.num_iterations = num_iterations
        self.start_iter = start_iter

    def __iter__(self):
        iteration = self.start_iter
        while iteration < self.num_iterations:
            # If the underlying sampler supports it (e.g. DistributedSampler),
            # reseed it so each pass can use a different ordering.
            if hasattr(self.batch_sampler.sampler, "set_epoch"):
                self.batch_sampler.sampler.set_epoch(iteration)
            produced_batch = False
            for batch in self.batch_sampler:
                produced_batch = True
                iteration += 1
                yield batch
                if iteration >= self.num_iterations:
                    break
            if not produced_batch:
                # An empty pass never advances ``iteration``; stop instead of
                # looping forever (e.g. when the dataset is empty).
                return

    def __len__(self):
        # NOTE: returns the total iteration budget, not the number of batches
        # remaining after ``start_iter``; kept unchanged for compatibility.
        return self.num_iterations
|
|
|
|
def make_batch_data_sampler(sampler, images_per_gpu, num_iters=None, start_iter=0):
    """Group the indices from ``sampler`` into batches of ``images_per_gpu``.

    The trailing partial batch is kept (``drop_last=False``). When
    ``num_iters`` is a non-negative number, the batches are wrapped in an
    ``IterationBasedBatchSampler`` that repeats passes until ``num_iters``
    iterations (counted from ``start_iter``) have been produced. Otherwise
    the plain ``BatchSampler`` is returned and ``start_iter`` is ignored.
    """
    plain_batches = torch.utils.data.sampler.BatchSampler(
        sampler, images_per_gpu, drop_last=False
    )
    iteration_limited = num_iters is not None and num_iters >= 0
    if not iteration_limited:
        return plain_batches
    return IterationBasedBatchSampler(plain_batches, num_iters, start_iter)
|
|
|
|
def make_data_sampler(dataset, shuffle, distributed):
    """Choose the index sampler for ``dataset``.

    Distributed runs get a ``DistributedSampler`` that is told whether to
    shuffle. Single-process runs get a ``RandomSampler`` when ``shuffle`` is
    truthy and a ``SequentialSampler`` otherwise.
    """
    data_utils = torch.utils.data
    if not distributed:
        sampler_cls = (
            data_utils.sampler.RandomSampler
            if shuffle
            else data_utils.sampler.SequentialSampler
        )
        return sampler_cls(dataset)
    return data_utils.distributed.DistributedSampler(dataset, shuffle=shuffle)
|
|
|
|
def make_data_loader(args, yaml_file, is_distributed=True,
                     is_train=True, start_iter=0, scale_factor=1):
    """Build a DataLoader over the dataset described by ``yaml_file``.

    Training mode:
        - shuffles the data and uses ``args.per_gpu_train_batch_size``;
        - runs for ``args.num_train_epochs`` epochs of full global batches,
          resuming the iteration count at ``start_iter``.

    Evaluation mode:
        - iterates the dataset once, in order;
        - uses ``args.per_gpu_eval_batch_size``;
        - ignores ``start_iter``.

    Args:
        args: namespace providing ``data_dir``, the batch-size fields,
            ``num_train_epochs`` and ``num_workers``.
        yaml_file: dataset YAML (see ``build_dataset``).
        is_distributed: use a ``DistributedSampler`` when true.
        is_train: training vs. evaluation mode (see above).
        start_iter: iteration to resume from (training only).
        scale_factor: forwarded to the dataset.

    Returns:
        A ``torch.utils.data.DataLoader`` with pinned memory.
    """
    dataset = build_dataset(yaml_file, args, is_train=is_train, scale_factor=scale_factor)
    logger = logging.getLogger(__name__)
    if is_train:
        shuffle = True
        images_per_gpu = args.per_gpu_train_batch_size
        # Global batch = per-process batch * number of processes.
        images_per_batch = images_per_gpu * get_world_size()
        # Floor division: only full global batches count toward an epoch.
        iters_per_epoch = len(dataset) // images_per_batch
        num_iters = iters_per_epoch * args.num_train_epochs
        logger.info("Train with {} images per GPU.".format(images_per_gpu))
        logger.info("Total batch size {}".format(images_per_batch))
        logger.info("Total training steps {}".format(num_iters))
        if num_iters == 0:
            # Otherwise training would silently run zero steps.
            logger.warning(
                "Computed 0 training steps (dataset size %d, total batch size %d, "
                "epochs %s); the training loader will yield no batches.",
                len(dataset), images_per_batch, args.num_train_epochs,
            )
    else:
        shuffle = False
        images_per_gpu = args.per_gpu_eval_batch_size
        # No iteration limit: evaluate each sample exactly once.
        num_iters = None
        start_iter = 0

    sampler = make_data_sampler(dataset, shuffle, is_distributed)
    batch_sampler = make_batch_data_sampler(
        sampler, images_per_gpu, num_iters, start_iter
    )
    data_loader = torch.utils.data.DataLoader(
        dataset, num_workers=args.num_workers, batch_sampler=batch_sampler,
        pin_memory=True,
    )
    return data_loader
|
|
|
|
| |
|
|
def build_hand_dataset(yaml_file, args, is_train=True, scale_factor=1):
    """Create a ``HandMeshTSVYamlDataset`` from a dataset YAML config.

    Args:
        yaml_file: path to the dataset YAML. If it does not name an existing
            file, it is looked up relative to ``args.data_dir``.
        args: namespace providing ``data_dir``; also passed to the dataset.
        is_train: forwarded to the dataset constructor.
        scale_factor: forwarded to the dataset constructor.

    Returns:
        A ``HandMeshTSVYamlDataset`` built as
        ``HandMeshTSVYamlDataset(args, path, is_train, False, scale_factor)``.

    Raises:
        FileNotFoundError: if the YAML exists neither as given nor under
            ``args.data_dir``.
    """
    print(yaml_file)
    resolved = yaml_file
    if not op.isfile(resolved):
        resolved = op.join(args.data_dir, yaml_file)

    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let a missing config slip through silently.
    if not op.isfile(resolved):
        raise FileNotFoundError(
            "Dataset YAML not found: {!r} (also tried {!r})".format(yaml_file, resolved)
        )
    return HandMeshTSVYamlDataset(args, resolved, is_train, False, scale_factor)
|
|
|
|
def make_hand_data_loader(args, yaml_file, is_distributed=True,
                          is_train=True, start_iter=0, scale_factor=1):
    """Build a DataLoader over the hand-mesh dataset described by ``yaml_file``.

    Training mode:
        - shuffles the data and uses ``args.per_gpu_train_batch_size``;
        - runs for ``args.num_train_epochs`` epochs of full global batches,
          resuming the iteration count at ``start_iter``.

    Evaluation mode:
        - iterates the dataset once, in order;
        - uses ``args.per_gpu_eval_batch_size``;
        - ignores ``start_iter``.

    Args:
        args: namespace providing ``data_dir``, the batch-size fields,
            ``num_train_epochs`` and ``num_workers``.
        yaml_file: dataset YAML (see ``build_hand_dataset``).
        is_distributed: use a ``DistributedSampler`` when true.
        is_train: training vs. evaluation mode (see above).
        start_iter: iteration to resume from (training only).
        scale_factor: forwarded to the dataset.

    Returns:
        A ``torch.utils.data.DataLoader`` with pinned memory.
    """
    dataset = build_hand_dataset(yaml_file, args, is_train=is_train, scale_factor=scale_factor)
    logger = logging.getLogger(__name__)
    if is_train:
        shuffle = True
        images_per_gpu = args.per_gpu_train_batch_size
        # Global batch = per-process batch * number of processes.
        images_per_batch = images_per_gpu * get_world_size()
        # Floor division: only full global batches count toward an epoch.
        iters_per_epoch = len(dataset) // images_per_batch
        num_iters = iters_per_epoch * args.num_train_epochs
        logger.info("Train with {} images per GPU.".format(images_per_gpu))
        logger.info("Total batch size {}".format(images_per_batch))
        logger.info("Total training steps {}".format(num_iters))
        if num_iters == 0:
            # Otherwise training would silently run zero steps.
            logger.warning(
                "Computed 0 training steps (dataset size %d, total batch size %d, "
                "epochs %s); the training loader will yield no batches.",
                len(dataset), images_per_batch, args.num_train_epochs,
            )
    else:
        shuffle = False
        images_per_gpu = args.per_gpu_eval_batch_size
        # No iteration limit: evaluate each sample exactly once.
        num_iters = None
        start_iter = 0

    sampler = make_data_sampler(dataset, shuffle, is_distributed)
    batch_sampler = make_batch_data_sampler(
        sampler, images_per_gpu, num_iters, start_iter
    )
    data_loader = torch.utils.data.DataLoader(
        dataset, num_workers=args.num_workers, batch_sampler=batch_sampler,
        pin_memory=True,
    )
    return data_loader
|
|
|
|