File size: 5,556 Bytes
01c7703 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
import sys
import os
import importlib
from omegaconf import OmegaConf
from tqdm.auto import tqdm
import torch
# Append the directory two levels above this file's directory to the import
# search path. NOTE(review): presumably so that project modules referenced by
# dotted path in the configs can be imported when this file is run directly --
# confirm against the repository layout.
sys.path.append(os.path.join(os.path.dirname(__file__),'../..'))
def get_obj_from_str(string, reload=False, invalidate_cache=True):
    """Resolve a dotted path such as ``"package.module.Name"`` to an object.

    Everything before the last ``.`` is imported as a module and the final
    component is looked up as an attribute of that module.

    Args:
        string: Dotted path of the form ``"<module path>.<attribute>"``.
        reload: If true, re-execute the module with ``importlib.reload``
            before the attribute lookup.
        invalidate_cache: If true, call ``importlib.invalidate_caches()``
            before importing.

    Returns:
        The attribute named by the last component of ``string``.

    Raises:
        ValueError: If ``string`` has no module part (no ``.`` at all, or a
            leading ``.`` with nothing before it).
        ImportError: If the module cannot be imported.
        AttributeError: If the module has no attribute with that name.
    """
    module_path, separator, attr_name = string.rpartition(".")
    # Fail with a message that names the offending input instead of an
    # opaque tuple-unpacking error; the exception type stays ValueError.
    if not separator or not module_path:
        raise ValueError(
            f"expected a dotted path like 'package.module.Name', got {string!r}"
        )
    if invalidate_cache:
        importlib.invalidate_caches()
    module_obj = importlib.import_module(module_path)
    if reload:
        # importlib.reload returns the reloaded module, so no second
        # import_module call is needed.
        module_obj = importlib.reload(module_obj)
    return getattr(module_obj, attr_name)
def prepare_dataloader_for_rank(config, global_rank, num_processes=-1, repeat_cp_size=1):
    """ Get the dataloader given config and the current global rank.

    "dataset_setting" provides the list of dataset configs
    "rank_index_map" provides how to distribute the config across ranks:
    rank ``r`` uses ``dataset_setting[rank_index_map[r % len(rank_index_map)]]``.

    Args:
        config: Config node read by attribute (``rank_index_map``,
            ``dataset_setting``); each dataset setting is read by attribute
            and via ``.get``. The ``__main__`` block of this file passes an
            OmegaConf node. The config is not modified by this function.
        global_rank: Rank of the calling process.
        num_processes: Total number of ranks. When positive, the ranks that
            map to the same dataset index as ``global_rank`` are counted to
            derive ``num_partitions`` and this rank's ``partition_id``.
            Otherwise a single partition (``num_partitions=1``,
            ``partition_id=0``) is used.
        repeat_cp_size: When greater than 1, every entry of
            ``rank_index_map`` is repeated that many times in a row,
            i.e. [a b c] --> [a a ... b b ... c c ...].

    Returns:
        A tuple ``(dataloader, loss_weight, bucket_configs)`` where
        ``loss_weight`` is this dataset's ``loss_weight`` (default 1.0)
        rescaled so that the weights average to 1.0 over the rank index map,
        and ``bucket_configs`` is ``dataset_kwargs["bucket_configs"]``.

    Raises:
        AssertionError: If ``dataset_kwargs`` has no ``bucket_configs``.
    """
    # Work on a local copy: writing the repeated map back into ``config``
    # would make a second call on the same config repeat it again.
    rank_index_map = list(config.rank_index_map)

    # repeat each elements in CP; [a b c] --> [a a ... b b ... c c ...]
    if repeat_cp_size > 1:
        print(f'before repeat config.rank_index_map: {rank_index_map}')
        rank_index_map = [element for element in rank_index_map for _ in range(repeat_cp_size)]
        print(f'after repeat repeated_rank_index_map: {rank_index_map}')

    # get the dataset index
    num_total_indices = len(rank_index_map)
    dataset_index = rank_index_map[global_rank % num_total_indices]

    # get the correct partition
    num_partitions = 1
    partition_id = 0
    if num_processes > 0:
        # Dataset index used by every rank in [0, num_processes).
        rank_to_dataset_index_map = [
            rank_index_map[rank % num_total_indices] for rank in range(num_processes)
        ]
        # Partitions = ranks sharing this dataset; this rank's id is the
        # number of such ranks that come before it.
        num_partitions = rank_to_dataset_index_map.count(dataset_index)
        partition_id = rank_to_dataset_index_map[:global_rank].count(dataset_index)
        print(f'rank_to_dataset_index_map: {rank_to_dataset_index_map}')
        print(f'dataset_index: {dataset_index} partition_id: {partition_id} num_partitions: {num_partitions} ')

    # get the loss weight scale factor to normalize loss weight to 1.0
    sum_loss_weight = sum(
        (config.dataset_setting[index].get("loss_weight", 1.0) for index in rank_index_map),
        0.0,
    )
    loss_weight_scale = float(num_total_indices) / sum_loss_weight

    # fetch the config
    dataset_setting = config.dataset_setting[dataset_index]
    loss_weight = dataset_setting.get("loss_weight", 1.0) * loss_weight_scale
    print(f'global_rank: {global_rank} -- dataset_index: {dataset_index} - loss_weight_scale: {loss_weight_scale} - loss weight: {loss_weight} - dataset_setting: {dataset_setting}')

    # set prompt function
    utils_prompt_module = importlib.import_module(dataset_setting.get_prompt_module)
    get_prompt_func = getattr(utils_prompt_module, dataset_setting.get_prompt_func)
    # The frame-span prompt function is optional.
    get_prompt_frame_spans_func = None
    if hasattr(dataset_setting, "get_prompt_frame_spans_func"):
        get_prompt_frame_spans_func = getattr(utils_prompt_module, dataset_setting.get_prompt_frame_spans_func)

    # get dataset from setting
    dataset_kwargs = dataset_setting.get("dataset_kwargs", dict())

    # get bucket configs (they are also forwarded to the dataset constructor
    # below as part of **dataset_kwargs)
    assert hasattr(dataset_kwargs, "bucket_configs"), \
        "dataset_kwargs must provide 'bucket_configs'"
    bucket_configs = dataset_kwargs.get("bucket_configs", dict())

    dataset = get_obj_from_str(dataset_setting.dataset_target)(
        get_prompt_func=get_prompt_func,
        get_prompt_frame_spans_func=get_prompt_frame_spans_func,
        partition_id=partition_id,
        num_partitions=num_partitions,
        **dataset_kwargs
    )

    # get dataloader from setting
    # shuffle / pin_memory / drop_last are fixed here, so supplying any of
    # them in dataloader_kwargs raises a duplicate-keyword TypeError.
    dataloader_kwargs = dataset_setting.get("dataloader_kwargs", dict())
    dataloader = torch.utils.data.DataLoader(
        dataset,
        **dataloader_kwargs,
        shuffle=False,
        pin_memory=True,
        drop_last=True,
        collate_fn=getattr(dataset, "collate_fn", None),
    )
    return dataloader, loss_weight, bucket_configs
if __name__ == '__main__':
    # Manual smoke test: build the dataloader for one rank from a training
    # config and pull up to ``num_train_steps`` batches from it.
    # example_config_path = 'source/dataset/example_config.yaml'
    example_config_path = "configs/train_t2v_opensora_v2_ms_long32_hq400.yaml"
    config = OmegaConf.load(example_config_path)

    # prepare_dataloader_for_rank returns (dataloader, loss_weight,
    # bucket_configs); unpack it so the loop below iterates over batches
    # rather than over the three tuple members.
    dataloader, _loss_weight, _bucket_configs = prepare_dataloader_for_rank(
        config.video_training_data_config, global_rank=7, num_processes=28
    )

    num_train_steps = 1000
    progress_bar = tqdm(range(0, num_train_steps))

    # output_dir = "assets/webvid-trimming_aes-tfreader"
    # os.makedirs(output_dir, exist_ok=True)

    # for step, batch in enumerate(tfreader):
    for step, batch in enumerate(dataloader):
        progress_bar.update(1)

        # Optional visualization dump, kept disabled. NOTE(review): it uses
        # `rearrange` and `save_videos_grid`, which are not imported in the
        # visible part of this file.
        # # save data for visualization
        # pixel_values = batch['pixel_values'].cpu()
        # pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w")
        # for idx, pixel_value in enumerate(pixel_values):
        #     pixel_value = pixel_value[None, ...]
        #     text_value = batch['text'][idx]
        #     of_score = batch['of_score'][idx]
        #     fps_value = batch['fps'][idx]
        #     text_value = (text_value[:70] + '..') if len(text_value) > 70 else text_value
        #     output_filename = f"{output_dir}/{f'{fps_value}-{of_score}-{text_value}'}.gif"
        #     print(f'saving data to {output_filename}')
        #     save_videos_grid(pixel_value, output_filename, rescale=True)
        # print(f'step: {step} / num_train_steps: {num_train_steps}')

        # ``step`` is zero-based: stop once exactly num_train_steps batches
        # have been consumed, so the progress bar is not overrun by one.
        if step + 1 >= num_train_steps:
            break
|