|
|
import torch |
|
|
from torch.utils.data import ConcatDataset, DataLoader |
|
|
from torchvision import transforms |
|
|
from torchvision.transforms import InterpolationMode |
|
|
from dataset.it_dataset import ITImgTrainDataset, ITVidTrainDataset |
|
|
|
|
|
|
|
|
def get_media_type(dataset_config):
    """Classify one train-file config entry as "video", "only_video", or "image".

    An entry tagged "only_video" in its last slot wins; a 3-element entry
    whose third slot is "video" is a video entry; everything else is image.
    """
    if dataset_config[-1] == "only_video":
        return "only_video"
    if len(dataset_config) == 3 and dataset_config[2] == "video":
        return "video"
    return "image"
|
|
|
|
|
|
|
|
def create_dataset(dataset_type, config):
    """Build the training dataset(s) described by ``config``.

    Args:
        dataset_type: Pipeline selector. Only "it_train" (instruction
            tuning) is implemented; "pt_train" raises; any other value
            falls through and returns None.
        config: OmegaConf-style config (attribute access plus ``.get``)
            holding model, input, preprocess, and train-file settings.

    Returns:
        For "it_train": a list of ``ConcatDataset``, one per distinct media
        type found in ``config.train_file`` (sorted by media-type name).

    Raises:
        ValueError: for "pt_train", or for an unrecognized vision encoder.
    """
    # Pick per-channel normalization statistics based on the vision backbone.
    if "clip" in config.model.get("vit_model", 'vit'):
        # NOTE(review): ImageNet stats are used on this path even though the
        # model is a CLIP ViT; the CLIP-specific stats below are only reached
        # via config.model.vision_encoder.name — confirm this is intended.
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)
    else:
        vision_enc_name = config.model.vision_encoder.name
        if "swin" in vision_enc_name or "vit" in vision_enc_name:
            # Standard ImageNet mean/std.
            mean = (0.485, 0.456, 0.406)
            std = (0.229, 0.224, 0.225)
        elif "beit" in vision_enc_name:
            mean = (0.5, 0.5, 0.5)
            std = (0.5, 0.5, 0.5)
        elif "clip" in vision_enc_name:
            # OpenAI CLIP preprocessing stats.
            mean = (0.48145466, 0.4578275, 0.40821073)
            std = (0.26862954, 0.26130258, 0.27577711)
        else:
            raise ValueError

    normalize = transforms.Normalize(mean, std)

    # Inputs are assumed to arrive as uint8 tensors in [0, 255]; scale to
    # float [0, 1] before Normalize — TODO confirm against the dataset readers.
    type_transform = transforms.Lambda(lambda x: x.float().div(255.0))

    if config.inputs.video_input.random_aug:
        aug_transform = transforms.RandAugment()
    else:
        # Identity pass-through when random augmentation is disabled.
        aug_transform = transforms.Lambda(lambda x: x)

    train_transform = transforms.Compose(
        [
            aug_transform,
            transforms.RandomResizedCrop(
                config.inputs.image_res,
                scale=(0.5, 1.0),
                interpolation=InterpolationMode.BICUBIC,
            ),
            transforms.RandomHorizontalFlip(),
            type_transform,
            normalize,
        ]
    )
    # NOTE(review): test_transform is built but never used or returned here —
    # presumably kept for parity with an eval path elsewhere; confirm.
    test_transform = transforms.Compose(
        [
            transforms.Resize(
                (config.inputs.image_res, config.inputs.image_res),
                interpolation=InterpolationMode.BICUBIC,
            ),
            type_transform,
            normalize,
        ]
    )

    video_reader_type = config.inputs.video_input.get("video_reader_type", "decord")
    # Shared kwargs for every video dataset; mutated per train-file below
    # (the reader type is re-set on every iteration, so carry-over between
    # files is overwritten each time).
    video_only_dataset_kwargs_train = dict(
        video_reader_type=video_reader_type,
        sample_type=config.inputs.video_input.sample_type,
        num_frames=config.inputs.video_input.num_frames,
        num_tries=3,  # presumably decode-retry count — confirm in ITVidTrainDataset
    )

    if dataset_type == "pt_train":
        raise ValueError("NOT PRETRAINING YET")
    elif dataset_type in ["it_train"]:

        # Normalize config.train_file to a list of entries: a single entry
        # (first element is a str) gets wrapped in a one-element list.
        train_files = (
            [config.train_file] if isinstance(config.train_file[0], str) else config.train_file
        )
        # Distinct media types present, in deterministic (sorted) order.
        train_media_types = sorted(list({get_media_type(e) for e in train_files}))

        train_datasets = []
        for m in train_media_types:
            # NOTE(review): "only_video" also maps to ITVidTrainDataset but
            # skips the m == "video" branch below, so those entries never
            # receive the video kwargs (num_frames, start/end tokens) — verify.
            dataset_cls = ITImgTrainDataset if m == "image" else ITVidTrainDataset

            # All files of this media type are concatenated into one dataset.
            _train_files = [e for e in train_files if get_media_type(e) == m]

            datasets = []
            for train_file in _train_files:
                dataset_kwargs = dict(
                    ann_file=train_file,
                    transform=train_transform,
                    mm_alone=config.preprocess.get("mm_alone", True),
                    add_second_msg=config.preprocess.get("add_second_msg", True),
                    skip_short_sample=config.preprocess.get("skip_short_sample", False),
                    clip_transform=config.preprocess.get("clip_transform", False),
                    random_shuffle=config.preprocess.get("random_shuffle", True),
                    system=config.preprocess.get("system", ""),
                    role=config.preprocess.get('roles', ("Human", "Assistant")),
                    end_signal=config.preprocess.get('end_signal', "###"),
                    begin_signal=config.preprocess.get('begin_signal', ""),
                )
                if m == "video":
                    video_only_dataset_kwargs_train.update({
                        "start_token": config.model.get("start_token", "<Video>"),
                        "end_token": config.model.get("end_token", "</Video>"),
                    })
                    dataset_kwargs.update(video_only_dataset_kwargs_train)
                    # Pick the video reader from a substring of the config
                    # entry — assumes train_file[1] is the data path; confirm.
                    if "tgif" in train_file[1]:
                        video_only_dataset_kwargs_train.update({
                            "video_reader_type": "gif"
                        })
                        dataset_kwargs.update(video_only_dataset_kwargs_train)
                    elif "webvid" in train_file[1]:
                        video_only_dataset_kwargs_train.update({
                            "video_reader_type": "hdfs"
                        })
                    else:
                        video_only_dataset_kwargs_train.update({
                            "video_reader_type": "decord"
                        })
                    # Re-apply so the just-chosen reader type reaches the
                    # dataset (redundant for the tgif branch, harmless).
                    dataset_kwargs.update(video_only_dataset_kwargs_train)
                datasets.append(dataset_cls(**dataset_kwargs))
            dataset = ConcatDataset(datasets)
            train_datasets.append(dataset)
        return train_datasets
|
|
|
|
|
|
|
|
def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
    """Build one DataLoader per dataset.

    The six arguments are parallel sequences; element i of each configures
    loader i. Training loaders shuffle (only when no sampler is supplied)
    and drop the last incomplete batch; evaluation loaders do neither.
    Workers are kept alive between epochs whenever any workers are used.
    """
    loaders = []
    for ds, smp, bsz, workers, training, collate in zip(
        datasets, samplers, batch_size, num_workers, is_trains, collate_fns
    ):
        # Shuffling is delegated to the sampler when one is provided;
        # DataLoader forbids combining sampler with shuffle=True.
        loaders.append(
            DataLoader(
                ds,
                batch_size=bsz,
                num_workers=workers,
                pin_memory=False,
                sampler=smp,
                shuffle=training and smp is None,
                collate_fn=collate,
                drop_last=training,
                persistent_workers=workers > 0,
            )
        )
    return loaders
|
|
|
|
|
|