| import logging
|
| import random
|
|
|
| import numpy as np
|
| import torch
|
| from omegaconf import DictConfig
|
| from torch.utils.data import DataLoader, Dataset
|
| from torch.utils.data.dataloader import default_collate
|
| from torch.utils.data.distributed import DistributedSampler
|
|
|
| from .eval.audiocaps import AudioCapsData
|
| from .eval.video_dataset import MovieGen, VGGSound
|
| from .extracted_audio import ExtractedAudio
|
| from .extracted_vgg import ExtractedVGG
|
| from .mm_dataset import MultiModalDataset
|
| from ..utils.dist_utils import local_rank
|
|
|
| log = logging.getLogger()
|
|
|
|
|
|
|
def worker_init_fn(worker_id: int):
    """Re-seed numpy and random in each DataLoader worker.

    The seed mixes torch's per-worker initial seed with the worker id and the
    distributed rank, so every (rank, worker) pair draws a distinct stream.
    """
    seed = torch.initial_seed() % (2**31) + worker_id + local_rank * 1000
    random.seed(seed)
    np.random.seed(seed)
    log.debug(f'Worker {worker_id} re-seeded with seed {seed} in rank {local_rank}')
|
|
|
|
|
def load_vgg_data(cfg: DictConfig, data_cfg: DictConfig) -> Dataset:
    """Build an ExtractedVGG dataset from a tsv index and a premade memmap dir.

    `cfg` supplies the global feature dimensionality; `data_cfg` points at the
    split-specific tsv and memmap directory.
    """
    return ExtractedVGG(tsv_path=data_cfg.tsv,
                        data_dim=cfg.data_dim,
                        premade_mmap_dir=data_cfg.memmap_dir)
|
|
|
|
|
def load_audio_data(cfg: DictConfig, data_cfg: DictConfig) -> Dataset:
    """Build an ExtractedAudio dataset from a tsv index and a premade memmap dir.

    Mirrors `load_vgg_data` but for audio-only sources.
    """
    return ExtractedAudio(tsv_path=data_cfg.tsv,
                          data_dim=cfg.data_dim,
                          premade_mmap_dir=data_cfg.memmap_dir)
|
|
|
|
|
def setup_training_datasets(cfg: DictConfig) -> tuple[Dataset, DistributedSampler, DataLoader]:
    """Build the training MultiModalDataset plus its distributed sampler/loader.

    Dataset selection (mutually exclusive, checked in order):
      - cfg.example_train: tiny example video+audio split
      - cfg.mini_train:    VGG val split + AudioCaps (quick debugging runs)
      - otherwise:         full training mix (VGG oversampled, five audio sets)

    BUG FIX: the original code used two separate `if` statements, so a config
    with mini_train=True and example_train=False fell through to the full-data
    `else` branch and silently discarded the mini dataset. The branches are now
    exclusive; example_train keeps its original precedence over mini_train.

    Returns:
        (dataset, sampler, loader) — loader shuffles and drops the last
        partial batch, as expected for training.
    """
    if cfg.example_train:
        video = load_vgg_data(cfg, cfg.data.Example_video)
        audio = load_audio_data(cfg, cfg.data.Example_audio)
        dataset = MultiModalDataset([video], [audio])
    elif cfg.mini_train:
        vgg = load_vgg_data(cfg, cfg.data.ExtractedVGG_val)
        audiocaps = load_audio_data(cfg, cfg.data.AudioCaps)
        dataset = MultiModalDataset([vgg], [audiocaps])
    else:
        freesound = load_audio_data(cfg, cfg.data.FreeSound)
        vgg = load_vgg_data(cfg, cfg.data.ExtractedVGG)
        audiocaps = load_audio_data(cfg, cfg.data.AudioCaps)
        audioset_sl = load_audio_data(cfg, cfg.data.AudioSetSL)
        bbcsound = load_audio_data(cfg, cfg.data.BBCSound)
        clotho = load_audio_data(cfg, cfg.data.Clotho)
        # Oversample the audio-visual data so it is not drowned out by the
        # much larger audio-only corpora.
        dataset = MultiModalDataset([vgg] * cfg.vgg_oversample_rate,
                                    [audiocaps, audioset_sl, bbcsound, freesound, clotho])

    sampler, loader = construct_loader(dataset,
                                       cfg.batch_size,
                                       cfg.num_workers,
                                       shuffle=True,
                                       drop_last=True,
                                       pin_memory=cfg.pin_memory)

    return dataset, sampler, loader
|
|
|
|
|
def setup_test_datasets(cfg):
    """Build the VGG test-split dataset with its sampler and loader.

    The loader does not shuffle and keeps the final partial batch so every
    test sample is visited exactly once per epoch.
    """
    dataset = load_vgg_data(cfg, cfg.data.ExtractedVGG_test)

    sampler, loader = construct_loader(dataset,
                                       cfg.batch_size,
                                       cfg.num_workers,
                                       shuffle=False,
                                       drop_last=False,
                                       pin_memory=cfg.pin_memory)

    return dataset, sampler, loader
|
|
|
|
|
def setup_val_datasets(cfg: DictConfig) -> tuple[Dataset, DataLoader, DataLoader]:
    """Build the validation dataset and two loaders over it.

    `val_loader` uses the training batch size (for validation-loss passes);
    `eval_loader` uses cfg.eval_batch_size (for generation/evaluation passes).
    Both iterate the same dataset in order, without dropping samples.
    """
    data_cfg = cfg.data.Example_video if cfg.example_train else cfg.data.ExtractedVGG_val
    dataset = load_vgg_data(cfg, data_cfg)

    shared = dict(num_workers=cfg.num_workers,
                  shuffle=False,
                  drop_last=False,
                  pin_memory=cfg.pin_memory)
    _, val_loader = construct_loader(dataset, cfg.batch_size, **shared)
    _, eval_loader = construct_loader(dataset, cfg.eval_batch_size, **shared)

    return dataset, val_loader, eval_loader
|
|
|
|
|
def setup_eval_dataset(dataset_name: str, cfg: DictConfig) -> tuple[Dataset, DataLoader]:
    """Resolve an evaluation dataset by name prefix and wrap it in a loader.

    Recognized prefixes: 'audiocaps_full', 'audiocaps', 'moviegen', 'vggsound'.
    The loader uses the error-avoidance collate so samples that fail to load
    (returned as None) are dropped instead of crashing the batch.

    Raises:
        ValueError: if `dataset_name` matches no known prefix.
    """
    # NOTE: 'audiocaps_full' must be checked before 'audiocaps' — they share a prefix.
    if dataset_name.startswith('audiocaps_full'):
        node = cfg.eval_data.AudioCaps_full
        dataset = AudioCapsData(node.audio_path, node.csv_path)
    elif dataset_name.startswith('audiocaps'):
        node = cfg.eval_data.AudioCaps
        dataset = AudioCapsData(node.audio_path, node.csv_path)
    elif dataset_name.startswith('moviegen'):
        node = cfg.eval_data.MovieGen
        dataset = MovieGen(node.video_path, node.jsonl_path, duration_sec=cfg.duration_s)
    elif dataset_name.startswith('vggsound'):
        node = cfg.eval_data.VGGSound
        dataset = VGGSound(node.video_path, node.csv_path, duration_sec=cfg.duration_s)
    else:
        raise ValueError(f'Invalid dataset name: {dataset_name}')

    _, loader = construct_loader(dataset,
                                 cfg.batch_size,
                                 cfg.num_workers,
                                 shuffle=False,
                                 drop_last=False,
                                 pin_memory=cfg.pin_memory,
                                 error_avoidance=True)
    return dataset, loader
|
|
|
|
|
def error_avoidance_collate(batch):
    """Collate wrapper that drops failed samples (loaded as None) first.

    Datasets used with error avoidance return None for items that could not
    be loaded; those are filtered out before delegating to default_collate.
    """
    valid_items = [item for item in batch if item is not None]
    return default_collate(valid_items)
|
|
|
|
|
def construct_loader(dataset: Dataset,
                     batch_size: int,
                     num_workers: int,
                     *,
                     shuffle: bool = True,
                     drop_last: bool = True,
                     pin_memory: bool = False,
                     error_avoidance: bool = False) -> tuple[DistributedSampler, DataLoader]:
    """Create a DistributedSampler/DataLoader pair for `dataset`.

    Workers are re-seeded via `worker_init_fn` and kept alive between epochs
    (persistent_workers) whenever num_workers > 0. With error_avoidance=True,
    None samples are filtered out by `error_avoidance_collate`.
    """
    collate = error_avoidance_collate if error_avoidance else None
    sampler = DistributedSampler(dataset, rank=local_rank, shuffle=shuffle)
    loader = DataLoader(dataset,
                        batch_size,
                        sampler=sampler,
                        num_workers=num_workers,
                        worker_init_fn=worker_init_fn,
                        drop_last=drop_last,
                        persistent_workers=num_workers > 0,
                        pin_memory=pin_memory,
                        collate_fn=collate)
    return sampler, loader
|
|
|