| import sys |
| import os |
| import importlib |
|
|
| from omegaconf import OmegaConf |
| from tqdm.auto import tqdm |
|
|
| import torch |
|
|
| sys.path.append(os.path.join(os.path.dirname(__file__),'../..')) |
|
|
|
|
|
|
def get_obj_from_str(string, reload=False, invalidate_cache=True):
    """Resolve a fully-qualified dotted path (e.g. ``"pkg.mod.ClassName"``)
    to the object it names.

    Args:
        string: Dotted path whose last component is the attribute to fetch.
        reload: If True, re-import the target module before the lookup so
            on-disk code changes are picked up.
        invalidate_cache: If True, refresh importlib's finder caches first
            (helps when modules were added to disk after interpreter start).

    Returns:
        The attribute named by the final dotted component.
    """
    module_path, attr_name = string.rsplit(".", 1)
    if invalidate_cache:
        importlib.invalidate_caches()
    if reload:
        importlib.reload(importlib.import_module(module_path))
    return getattr(importlib.import_module(module_path, package=None), attr_name)
|
|
|
|
def prepare_dataloader_for_rank(config, global_rank, num_processes=-1, repeat_cp_size=1):
    """ Get the dataloader given config and the current global rank.
    "dataset_setting" provides the list of dataset configs
    "rank_index_map" provides how to distribute the config across ranks

    Args:
        config: Config node (e.g. an OmegaConf DictConfig) holding
            ``rank_index_map`` (list of indices into ``dataset_setting``)
            and ``dataset_setting`` (list of per-dataset configs).
        global_rank: Global rank of this process; selects an entry of
            ``rank_index_map`` modulo its length.
        num_processes: Total number of processes. When > 0, ranks that map to
            the same dataset index each get a disjoint partition of it.
        repeat_cp_size: Context-parallel group size. When > 1, every entry of
            ``rank_index_map`` is repeated so all ranks in one CP group read
            the same dataset. NOTE: this mutates ``config.rank_index_map``
            in place.

    Returns:
        Tuple of ``(dataloader, loss_weight, bucket_configs)`` where
        ``loss_weight`` is the dataset's weight rescaled so weights average
        to 1.0 across the rank-index map.
    """

    # Duplicate each map entry so an entire context-parallel group shares one dataset.
    if repeat_cp_size > 1:
        print(f'before repeat config.rank_index_map: {config.rank_index_map}')
        repeated_rank_index_map = [element for element in config.rank_index_map for _ in range(repeat_cp_size)]
        config.rank_index_map = repeated_rank_index_map
        print(f'after repeat repeated_rank_index_map: {config.rank_index_map}')

    # Which dataset this rank reads (map is tiled cyclically across ranks).
    num_total_indices = len(config.rank_index_map)
    dataset_index = config.rank_index_map[global_rank % num_total_indices]

    # Ranks sharing a dataset index split it into disjoint partitions:
    # this rank's partition id = how many lower ranks use the same index.
    num_partitions = 1
    partition_id = 0
    if num_processes > 0:
        rank_to_dataset_index_map = list(config.rank_index_map) * num_processes
        rank_to_dataset_index_map = rank_to_dataset_index_map[:num_processes]
        num_partitions = rank_to_dataset_index_map.count(dataset_index)
        partition_id = rank_to_dataset_index_map[:global_rank].count(dataset_index)
        print(f'rank_to_dataset_index_map: {rank_to_dataset_index_map}')
        print(f'dataset_index: {dataset_index} partition_id: {partition_id} num_partitions: {num_partitions} ')

    # Rescale loss weights so their mean over the rank-index map is 1.0.
    sum_loss_weight = 0.0
    for i in range(num_total_indices):
        dataset_setting = config.dataset_setting[config.rank_index_map[i]]
        sum_loss_weight += dataset_setting.get("loss_weight", 1.0)
    loss_weight_scale = float(num_total_indices) / sum_loss_weight

    dataset_setting = config.dataset_setting[dataset_index]
    loss_weight = dataset_setting.get("loss_weight", 1.0) * loss_weight_scale
    print(f'global_rank: {global_rank} -- dataset_index: {dataset_index} - loss_weight_scale: {loss_weight_scale} - loss weight: {loss_weight} - dataset_setting: {dataset_setting}')

    # Resolve the prompt helper(s) named in the config by dotted module path.
    utils_prompt_module = importlib.import_module(dataset_setting.get_prompt_module)
    get_prompt_func = getattr(utils_prompt_module, dataset_setting.get_prompt_func)
    get_prompt_frame_spans_func = None
    if hasattr(dataset_setting, "get_prompt_frame_spans_func"):
        get_prompt_frame_spans_func = getattr(utils_prompt_module, dataset_setting.get_prompt_frame_spans_func)

    dataset_kwargs = dataset_setting.get("dataset_kwargs", dict())

    # bucket_configs is required. Use a membership test rather than hasattr:
    # the dict() fallback above can yield a plain dict, for which
    # hasattr(dataset_kwargs, "bucket_configs") is always False even when the
    # key exists; "in" works for both dict and OmegaConf DictConfig.
    assert "bucket_configs" in dataset_kwargs
    bucket_configs = dataset_kwargs.get("bucket_configs", dict())

    dataset = get_obj_from_str(dataset_setting.dataset_target)(
        get_prompt_func=get_prompt_func,
        get_prompt_frame_spans_func=get_prompt_frame_spans_func,
        partition_id=partition_id,
        num_partitions=num_partitions,
        **dataset_kwargs
    )

    dataloader_kwargs = dataset_setting.get("dataloader_kwargs", dict())
    dataloader = torch.utils.data.DataLoader(
        dataset,
        **dataloader_kwargs,
        shuffle=False,
        pin_memory=True,
        drop_last=True,
        collate_fn = dataset.collate_fn if hasattr(dataset,"collate_fn") else None,
    )

    return dataloader, loss_weight, bucket_configs
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke-test entry point: build one rank's dataloader from an example
    # config and drain it for a fixed number of steps.
    example_config_path = "configs/train_t2v_opensora_v2_ms_long32_hq400.yaml"
    config = OmegaConf.load(example_config_path)

    # prepare_dataloader_for_rank returns (dataloader, loss_weight,
    # bucket_configs); the original bound the whole tuple to `dataloader`
    # and then tried to iterate the 3-tuple, which would fail. Unpack it.
    dataloader, loss_weight, bucket_configs = prepare_dataloader_for_rank(
        config.video_training_data_config, global_rank=7, num_processes=28)

    num_train_steps = 1000
    progress_bar = tqdm(range(0, num_train_steps))

    for step, batch in enumerate(dataloader):
        progress_bar.update(1)

        if step >= num_train_steps:
            break
|
|