import torch from torch.utils.data.distributed import DistributedSampler from .batch_sampler import BucketSampler from .dataset import LRCRecordLoader from .dataset import Dataset, collate_func from libs.utils.comm import distributed, get_rank, get_world_size from . import transform as T def create_train_dataloader(vocab, lrcs_path, num_workers, max_batch_size, max_pixel_nums, bucket_seps, data_root_dir): loaders = list() for lrc_path in lrcs_path: loader = LRCRecordLoader(lrc_path, data_root_dir) loaders.append(loader) transforms = T.Compose([ T.TableToLabel(vocab), T.CalRowColSpans(), T.CalCellSpans(), T.CalHeadBodyDivide(), T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) dataset = Dataset(loaders, transforms) batch_sampler = BucketSampler(dataset, get_world_size(), get_rank(), max_pixel_nums=max_pixel_nums, max_batch_size=max_batch_size,seps=bucket_seps) dataloader = torch.utils.data.DataLoader( dataset, num_workers=num_workers, collate_fn=collate_func, batch_sampler=batch_sampler ) return dataloader def create_valid_dataloader(vocab, lrc_path, num_workers, batch_size, data_root_dir): loader = LRCRecordLoader(lrc_path, data_root_dir) transforms = T.Compose([ T.TableToLabel(vocab), T.CalRowColSpans(), T.CalCellSpans(), T.CalHeadBodyDivide(), T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) dataset = Dataset([loader], transforms) if distributed(): sampler = DistributedSampler(dataset, get_world_size(), get_rank(), True) dataloader = torch.utils.data.DataLoader( dataset, num_workers=num_workers, batch_size=batch_size, collate_fn=collate_func, sampler=sampler, drop_last=False ) else: dataloader = torch.utils.data.DataLoader( dataset, num_workers=num_workers, batch_size=batch_size, collate_fn=collate_func, shuffle=False, drop_last=False ) return dataloader