# Copyright 2024 EPFL and Apple Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| import copy | |
| import torch | |
| import yaml | |
| import fourm.utils as utils | |
| from fourm.data import (CenterCropImageAugmenter, EmptyAugmenter, | |
| PreTokenizedImageAugmenter,RandomCropImageAugmenter, build_fm_pretraining_dataset, | |
| build_huggingface_pretraining_dataloader, | |
| build_wds_fm_pretraining_dataloader) | |
| from fourm.data.modality_transforms import CaptionTransform | |
| from fourm.data.modality_info import MODALITY_TRANSFORMS | |
def setup_sampling_mod_info(dataset_config, modality_info):
    """Build the per-dataset modality-info dict and optional sampling weights.

    Restricts ``modality_info`` to the modalities this dataset uses and fills
    in the Dirichlet concentration parameters ("alphas") that control input /
    target token sampling, either from inline config strings or from a
    separate alphas YAML file.

    Args:
        dataset_config: Per-dataset config dict. Uses 'in_domains' /
            'out_domains' ('-'-separated modality names) and optionally
            'input_alphas', 'target_alphas', or 'alphas_config' (YAML path).
        modality_info: Global modality info dict; not mutated.

    Returns:
        Tuple ``(mod_info, sampling_weights)`` where ``mod_info`` is a deep
        copy restricted to the used modalities with alphas filled in, and
        ``sampling_weights`` is a mixture-weight list or None.
    """
    input_mods = sorted(dataset_config['in_domains'].split('-'))
    target_mods = sorted(dataset_config['out_domains'].split('-'))
    used_mods = sorted(set(input_mods) | set(target_mods))

    # Deep-copy so alpha assignments below never touch the caller's dict.
    mod_info = {mod: copy.deepcopy(modality_info[mod]) for mod in used_mods}

    alphas_path = dataset_config.get('alphas_config', None)
    if alphas_path is None:
        # Inline alphas: default everything to 0., then apply the
        # '-'-separated overrides from the dataset config.
        for info in mod_info.values():
            info["input_alphas"] = [0.]
            info["target_alphas"] = [0.]

        def _parse_alphas(spec, mods):
            # A single value is broadcast across all modalities.
            values = spec.split('-')
            if len(values) == 1:
                return [float(values[0])] * len(mods)
            return [float(v) for v in values]

        if 'input_alphas' in dataset_config:
            parsed = _parse_alphas(dataset_config['input_alphas'], input_mods)
            for mod, alpha in zip(input_mods, parsed):
                mod_info[mod]['input_alphas'] = [alpha]
        if 'target_alphas' in dataset_config:
            parsed = _parse_alphas(dataset_config['target_alphas'], target_mods)
            for mod, alpha in zip(target_mods, parsed):
                mod_info[mod]["target_alphas"] = [alpha]
        sampling_weights = None
    else:
        # Alphas (and optionally mixture sampling weights) come from a
        # separate YAML file.
        print(f"Loading alphas config from: {dataset_config['alphas_config']}")
        with open(alphas_path, "r") as f:
            alphas_config = yaml.safe_load(f)
        sampling_weights = alphas_config.get('sampling_weights', None)
        if 'sampling_weights' in alphas_config:
            alphas_config = alphas_config['alphas_mixture']
        for mod, info in mod_info.items():
            info["input_alphas"] = alphas_config[mod]["input_alphas"]
            info["target_alphas"] = alphas_config[mod]["target_alphas"]
            # Sequence-like modalities additionally carry a 'keep' setting.
            if modality_info[mod]['type'] in ('seq', 'seq_emb', 'seq_token'):
                info['keep'] = alphas_config[mod]['keep']

    return mod_info, sampling_weights
def get_train_dataloader(dataset_config, modality_info, sampling_weights, text_tokenizer, input_size,
                         num_input_tokens, num_target_tokens, min_input_tokens, min_target_tokens,
                         num_tasks, num_workers, dataset_batch_size=None, epoch_size=None):
    """Build the training dataloader for a single dataset config.

    Args:
        dataset_config: Per-dataset config dict ('type', paths, augmentation
            settings, and optional per-dataset token-budget overrides).
        modality_info: Modality info dict for this dataset (see
            setup_sampling_mod_info).
        sampling_weights: Optional mixture sampling weights (webdataset path only).
        text_tokenizer: Tokenizer for text/caption modalities.
        input_size: Target image side length for augmentation.
        num_input_tokens / num_target_tokens: Global max token budgets; each
            may be overridden per dataset via dataset_config.
        min_input_tokens / min_target_tokens: Lower bounds of the sampled token
            ranges; None collapses the range to the fixed max budget.
        num_tasks: Distributed world size.
        num_workers: DataLoader worker count.
        dataset_batch_size: Per-dataset batch size (None where collation
            happens later in the mixture dataloader).
        epoch_size: Optional number of samples defining one epoch.

    Returns:
        An iterable dataloader yielding multimodal training samples.

    Raises:
        NotImplementedError: For an unknown dataset type, or webdataset with
            huggingface datasets.
    """
    in_domains = sorted(dataset_config['in_domains'].split('-'))
    out_domains = sorted(dataset_config['out_domains'].split('-'))
    all_domains = sorted(set(in_domains) | set(out_domains))

    # Copy the registry before customizing it: assigning into the module-level
    # MODALITY_TRANSFORMS dict would leak this dataset's caption settings
    # (e.g. aligned_captions) into every other dataloader built afterwards.
    modality_transforms = dict(MODALITY_TRANSFORMS)
    if 'caption' in modality_transforms:
        modality_transforms['caption'] = CaptionTransform(
            aligned_captions=dataset_config.get('aligned_captions', True)
        )

    if dataset_config['type'] == 'multimodal':
        is_pretokenized = any(modality_info[mod].get('pretokenized', False) for mod in modality_info)
        if is_pretokenized:
            # Multi-modal training data augmentation (uses pre-tokenized data
            # augmentation, replayed from the main domain's crop settings).
            image_augmenter = PreTokenizedImageAugmenter(
                target_size=input_size,
                no_aug=(not dataset_config.get('tok_train_aug', True)),
                main_domain=dataset_config['main_augment_domain']
            )
        else:
            # NOTE(review): 'hflip'/'crop_scale'/'crop_ratio' have no defaults
            # here, so they are expected to be present in the config.
            image_augmenter = RandomCropImageAugmenter(
                target_size=input_size,
                hflip=dataset_config.get('hflip'),
                crop_scale=tuple(dataset_config.get('crop_scale')),
                crop_ratio=tuple(dataset_config.get('crop_ratio')),
            )

        # Per-dataset overrides of the global token budgets.
        num_input_tokens = dataset_config.get('num_input_tokens', num_input_tokens)
        num_target_tokens = dataset_config.get('num_target_tokens', num_target_tokens)
        min_input_tokens = dataset_config.get('min_input_tokens', min_input_tokens)
        min_target_tokens = dataset_config.get('min_target_tokens', min_target_tokens)
        # A None minimum collapses the sampled range to a fixed token count.
        min_input_tokens = num_input_tokens if min_input_tokens is None else min_input_tokens
        min_target_tokens = num_target_tokens if min_target_tokens is None else min_target_tokens

        if dataset_config['use_wds']:
            # Using webdataset
            loader = build_wds_fm_pretraining_dataloader(
                data_path=dataset_config['data_path'], all_domains=all_domains,
                modality_info=modality_info, modality_transforms=modality_transforms,
                image_augmenter=image_augmenter, text_tokenizer=text_tokenizer,
                input_tokens_range=(min_input_tokens, num_input_tokens),
                target_tokens_range=(min_target_tokens, num_target_tokens),
                num_gpus=num_tasks, num_workers=num_workers,
                batch_size=dataset_batch_size, epoch_size=epoch_size,
                modality_name_map=dataset_config.get('modality_name_map', None),
                shuffle_buffer_load=dataset_config.get('wds_shuffle_buffer_tar', 1_000),
                shuffle_buffer_repeat=dataset_config.get('wds_shuffle_buffer_repeat', 1_000),
                n_repeats=dataset_config.get('wds_n_repeats', 1),
                sampling_weights=sampling_weights,
            )
        else:
            dataset_train = build_fm_pretraining_dataset(
                data_path=dataset_config['data_path'],
                all_domains=all_domains, modality_info=modality_info, modality_transforms=modality_transforms,
                image_augmenter=image_augmenter, text_tokenizer=text_tokenizer,
                input_tokens_range=(min_input_tokens, num_input_tokens),
                target_tokens_range=(min_target_tokens, num_target_tokens)
            )
            sampler_train = torch.utils.data.DistributedSampler(
                dataset_train, num_replicas=num_tasks, rank=utils.get_rank(), shuffle=True, drop_last=True,
            )
            # DataLoader has batch size 1 as it then gets collated through the
            # Mixture dataloader; collate_fn unwraps the singleton batch.
            loader = torch.utils.data.DataLoader(
                dataset_train, sampler=sampler_train,
                batch_size=1, num_workers=0,
                pin_memory=False, drop_last=True,
                collate_fn=lambda x: x[0],
            )

    elif dataset_config['type'] == 'huggingface':
        # Huggingface streaming datasets use fixed token counts (no minimum).
        num_input_tokens = dataset_config.get('num_input_tokens', num_input_tokens)
        num_target_tokens = dataset_config.get('num_target_tokens', num_target_tokens)

        if dataset_config.get('use_wds', False):
            raise NotImplementedError('Webdataset not yet implemented for huggingface datasets.')
        else:
            loader = build_huggingface_pretraining_dataloader(
                data_path=dataset_config['data_path'], all_domains=all_domains,
                modality_info=modality_info, modality_transforms=modality_transforms,
                image_augmenter=EmptyAugmenter(), text_tokenizer=text_tokenizer,
                input_tokens_range=(num_input_tokens, num_input_tokens),
                target_tokens_range=(num_target_tokens, num_target_tokens),
                num_gpus=num_tasks, num_workers=num_workers,
                batch_size=dataset_batch_size, epoch_size=epoch_size,
                split='train', streaming=True, rename_text_to_caption=True,
                shuffle_buffer_load=dataset_config.get('shuffle_buffer_load', 1_000),
                shuffle_seed=0,
            )
    else:
        raise NotImplementedError(f'Dataset type {dataset_config["type"]} not implemented.')

    return loader
def cfgs_get(key, val_config, dataset_name, train_configs, default=None):
    """ Try to retrieve a key from the validation set config.
    If it does not exist, default to retrieving it from the train set config
    with the same dataset name.
    """
    # Resolve the train-config fallback first, then prefer the val config.
    fallback = train_configs[dataset_name].get(key, default)
    return val_config.get(key, fallback)
def get_val_dataloader(dataset_config, dataset_name, train_configs, modality_info, sampling_weights, text_tokenizer,
                       input_size, num_input_tokens, num_target_tokens, min_input_tokens, min_target_tokens,
                       fixed_eval, fixed_eval_input_tokens, fixed_eval_target_tokens,
                       dist_eval, num_tasks, num_workers, batch_size, pin_mem):
    """Build the validation dataloader for a single dataset config.

    Values missing from the validation config fall back to the train config of
    the same dataset name (see cfgs_get).

    Args:
        dataset_config: Validation config dict for this dataset.
        dataset_name: Name used to look up the matching train config.
        train_configs: Dict of train configs keyed by dataset name.
        modality_info: Modality info dict for this dataset.
        sampling_weights: Unused here; kept for signature symmetry with
            get_train_dataloader.
        text_tokenizer: Tokenizer for text/caption modalities.
        input_size: Target image side length for augmentation.
        num_input_tokens / num_target_tokens: Global max token budgets.
        min_input_tokens / min_target_tokens: Optional lower bounds; None
            collapses the range to the fixed max budget.
        fixed_eval: If True, use the fixed_eval_* token counts instead of the
            configured ranges, for comparable eval numbers across runs.
        fixed_eval_input_tokens / fixed_eval_target_tokens: Fixed token counts
            used when fixed_eval is set.
        dist_eval: Whether to shard evaluation across processes.
        num_tasks: Distributed world size.
        num_workers: DataLoader worker count.
        batch_size: Validation batch size.
        pin_mem: Whether to pin host memory in the DataLoader.

    Returns:
        An iterable dataloader yielding multimodal validation samples.

    Raises:
        NotImplementedError: For an unknown dataset type.
    """
    in_domains = sorted(cfgs_get('in_domains', dataset_config, dataset_name, train_configs).split('-'))
    out_domains = sorted(cfgs_get('out_domains', dataset_config, dataset_name, train_configs).split('-'))
    all_domains = sorted(set(in_domains) | set(out_domains))

    # Copy the registry before customizing it: assigning into the module-level
    # MODALITY_TRANSFORMS dict would leak this dataset's caption settings
    # (e.g. aligned_captions) into every other dataloader built afterwards.
    modality_transforms = dict(MODALITY_TRANSFORMS)
    if 'caption' in modality_transforms:
        modality_transforms['caption'] = CaptionTransform(
            aligned_captions=cfgs_get('aligned_captions', dataset_config, dataset_name, train_configs, True)
        )

    dataset_type = cfgs_get('type', dataset_config, dataset_name, train_configs)

    if dataset_type == 'multimodal':
        main_augment_domain = cfgs_get('main_augment_domain', dataset_config, dataset_name, train_configs)
        is_pretokenized = any(modality_info[mod].get('pretokenized', False) for mod in modality_info)
        if is_pretokenized:
            eval_image_augmenter = PreTokenizedImageAugmenter(
                target_size=input_size, no_aug=True, main_domain=main_augment_domain
            )
        else:
            # Deterministic center crop for evaluation.
            eval_image_augmenter = CenterCropImageAugmenter(
                target_size=input_size, main_domain=main_augment_domain
            )

        if fixed_eval:
            input_tokens_range = (fixed_eval_input_tokens, fixed_eval_input_tokens)
            target_tokens_range = (fixed_eval_target_tokens, fixed_eval_target_tokens)
        else:
            # Per-dataset overrides of the global token budgets.
            num_input_tokens = dataset_config.get('num_input_tokens', num_input_tokens)
            num_target_tokens = dataset_config.get('num_target_tokens', num_target_tokens)
            min_input_tokens = dataset_config.get('min_input_tokens', min_input_tokens)
            min_target_tokens = dataset_config.get('min_target_tokens', min_target_tokens)
            min_input_tokens = num_input_tokens if min_input_tokens is None else min_input_tokens
            min_target_tokens = num_target_tokens if min_target_tokens is None else min_target_tokens
            input_tokens_range = (min_input_tokens, num_input_tokens)
            target_tokens_range = (min_target_tokens, num_target_tokens)

        dataset_val = build_fm_pretraining_dataset(
            data_path=cfgs_get('data_path', dataset_config, dataset_name, train_configs),
            all_domains=all_domains, modality_info=modality_info, modality_transforms=modality_transforms,
            image_augmenter=eval_image_augmenter, text_tokenizer=text_tokenizer,
            input_tokens_range=input_tokens_range, target_tokens_range=target_tokens_range
        )
        print("Warning: Eval stats may vary slightly as the masking applied on images is random.")

        if dist_eval:
            if len(dataset_val) % num_tasks != 0:
                print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                      'This will slightly alter validation results as extra duplicate entries are added to achieve '
                      'equal num of samples per-process.')
            sampler_val = torch.utils.data.DistributedSampler(
                dataset_val, num_replicas=num_tasks, rank=utils.get_rank(), shuffle=False)
        else:
            sampler_val = torch.utils.data.SequentialSampler(dataset_val)

        loader = torch.utils.data.DataLoader(
            dataset_val, sampler=sampler_val,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_mem,
            drop_last=False,
        )

    elif dataset_type == 'huggingface':
        if fixed_eval:
            input_tokens_range = (fixed_eval_input_tokens, fixed_eval_input_tokens)
            target_tokens_range = (fixed_eval_target_tokens, fixed_eval_target_tokens)
        else:
            # Huggingface streaming datasets use fixed token counts (no minimum).
            num_input_tokens = dataset_config.get('num_input_tokens', num_input_tokens)
            num_target_tokens = dataset_config.get('num_target_tokens', num_target_tokens)
            input_tokens_range = (num_input_tokens, num_input_tokens)
            target_tokens_range = (num_target_tokens, num_target_tokens)

        loader = build_huggingface_pretraining_dataloader(
            data_path=cfgs_get('data_path', dataset_config, dataset_name, train_configs),
            all_domains=all_domains, modality_info=modality_info, modality_transforms=modality_transforms,
            image_augmenter=EmptyAugmenter(), text_tokenizer=text_tokenizer,
            input_tokens_range=input_tokens_range, target_tokens_range=target_tokens_range,
            num_gpus=num_tasks, num_workers=num_workers,
            batch_size=batch_size, epoch_size=None,
            split='validation', streaming=True, rename_text_to_caption=True,
            shuffle_buffer_load=cfgs_get('shuffle_buffer_load', dataset_config, dataset_name, train_configs, 1_000),
            shuffle_seed=0,
        )
    else:
        raise NotImplementedError(f'Dataset type {dataset_type} not implemented.')

    return loader