training_sem / libs /data /__init__.py
kai-2054's picture
Initial commit: add code
cb0ad2d
raw
history blame
2.26 kB
import torch
from torch.utils.data.distributed import DistributedSampler
from .batch_sampler import BucketSampler
from .dataset import LRCRecordLoader
from .dataset import Dataset, collate_func
from libs.utils.comm import distributed, get_rank, get_world_size
from . import transform as T
def create_train_dataloader(vocab, lrcs_path, num_workers, max_batch_size, max_pixel_nums, bucket_seps, data_root_dir):
    """Build the training DataLoader, batched by a ``BucketSampler``.

    Args:
        vocab: Passed to ``T.TableToLabel``.
        lrcs_path: Iterable of record paths; one ``LRCRecordLoader`` is
            created per entry.
        num_workers: Forwarded to ``torch.utils.data.DataLoader``.
        max_batch_size: Forwarded to ``BucketSampler`` as ``max_batch_size``.
        max_pixel_nums: Forwarded to ``BucketSampler`` as ``max_pixel_nums``.
        bucket_seps: Forwarded to ``BucketSampler`` as ``seps``.
        data_root_dir: Second argument given to every ``LRCRecordLoader``.

    Returns:
        A ``torch.utils.data.DataLoader`` over the combined dataset that
        uses ``collate_func`` and the bucket sampler as its batch sampler.
    """
    # One record loader per path, all sharing the same data root.
    record_loaders = [
        LRCRecordLoader(record_path, data_root_dir)
        for record_path in lrcs_path
    ]

    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]
    pipeline = T.Compose([
        T.TableToLabel(vocab),
        T.CalRowColSpans(),
        T.CalCellSpans(),
        T.CalHeadBodyDivide(),
        T.ToTensor(),
        T.Normalize(mean=norm_mean, std=norm_std),
    ])

    train_set = Dataset(record_loaders, pipeline)

    # The sampler receives the world size and this process's rank, and
    # yields whole batches, so DataLoader gets it via ``batch_sampler``.
    bucket_sampler = BucketSampler(
        train_set,
        get_world_size(),
        get_rank(),
        max_pixel_nums=max_pixel_nums,
        max_batch_size=max_batch_size,
        seps=bucket_seps,
    )

    return torch.utils.data.DataLoader(
        train_set,
        num_workers=num_workers,
        collate_fn=collate_func,
        batch_sampler=bucket_sampler,
    )
def create_valid_dataloader(vocab, lrc_path, num_workers, batch_size, data_root_dir):
    """Build the validation DataLoader for a single record path.

    Args:
        vocab: Passed to ``T.TableToLabel``.
        lrc_path: Record path handed to ``LRCRecordLoader``.
        num_workers: Forwarded to ``torch.utils.data.DataLoader``.
        batch_size: Forwarded to ``torch.utils.data.DataLoader``.
        data_root_dir: Second argument given to ``LRCRecordLoader``.

    Returns:
        A ``torch.utils.data.DataLoader`` using ``collate_func`` with
        ``drop_last=False``. When ``distributed()`` is true the loader is
        driven by a ``DistributedSampler``; otherwise it iterates with
        ``shuffle=False``.
    """
    record_loader = LRCRecordLoader(lrc_path, data_root_dir)

    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]
    pipeline = T.Compose([
        T.TableToLabel(vocab),
        T.CalRowColSpans(),
        T.CalCellSpans(),
        T.CalHeadBodyDivide(),
        T.ToTensor(),
        T.Normalize(mean=norm_mean, std=norm_std),
    ])

    valid_set = Dataset([record_loader], pipeline)

    # Options common to both the distributed and single-process loaders.
    loader_kwargs = dict(
        num_workers=num_workers,
        batch_size=batch_size,
        collate_fn=collate_func,
        drop_last=False,
    )

    if not distributed():
        return torch.utils.data.DataLoader(valid_set, shuffle=False, **loader_kwargs)

    # NOTE(review): the distributed sampler is created with shuffle=True
    # while the single-process path above uses shuffle=False -- confirm
    # whether shuffling validation data here is intentional.
    rank_sampler = DistributedSampler(
        valid_set,
        num_replicas=get_world_size(),
        rank=get_rank(),
        shuffle=True,
    )
    return torch.utils.data.DataLoader(valid_set, sampler=rank_sampler, **loader_kwargs)