|
|
import sys |
|
|
import os |
|
|
import importlib |
|
|
|
|
|
from omegaconf import OmegaConf |
|
|
from tqdm.auto import tqdm |
|
|
|
|
|
import torch |
|
|
|
|
|
sys.path.append(os.path.join(os.path.dirname(__file__),'../..')) |
|
|
|
|
|
|
|
|
|
|
|
def get_obj_from_str(string, reload=False, invalidate_cache=True):
    """Resolve a dotted path like ``"pkg.mod.ClassName"`` to the named object.

    Everything before the last dot is treated as a module path; the final
    component is the attribute fetched from that module.

    Args:
        string: Fully qualified name of the target object.
        reload: If True, reload the module before resolving the attribute,
            picking up on-disk changes.
        invalidate_cache: If True, clear importlib's finder caches first so
            modules installed after interpreter start are discoverable.

    Returns:
        The attribute (class, function, constant, ...) named by ``string``.
    """
    module_path, attr_name = string.rsplit(".", 1)
    if invalidate_cache:
        importlib.invalidate_caches()
    if reload:
        importlib.reload(importlib.import_module(module_path))
    return getattr(importlib.import_module(module_path, package=None), attr_name)
|
|
|
|
|
|
|
|
def prepare_dataloader_for_rank(config, global_rank, num_processes=-1, repeat_cp_size=1):
    """ Build the DataLoader assigned to the current global rank.

    "dataset_setting" provides the list of dataset configs.
    "rank_index_map" provides how to distribute the configs across ranks:
    entry i of the map is an index into "dataset_setting", and rank r uses
    entry r modulo len(rank_index_map).

    Args:
        config: OmegaConf node carrying "dataset_setting" and "rank_index_map".
        global_rank: Global rank of this process.
        num_processes: Total number of ranks; when > 0, the chosen dataset is
            additionally partitioned among all ranks that share the same
            dataset index.
        repeat_cp_size: Context-parallel group size; when > 1, each entry of
            "rank_index_map" is repeated so a whole CP group shares one dataset.

    Returns:
        Tuple of (dataloader, loss_weight, bucket_configs).
    """

    if repeat_cp_size > 1:

        print(f'before repeat config.rank_index_map: {config.rank_index_map}')

        # Repeat each entry repeat_cp_size times, e.g. [0, 1] with
        # repeat_cp_size=2 -> [0, 0, 1, 1], so consecutive CP ranks map to
        # the same dataset. NOTE: this mutates the caller's config in place.
        repeated_rank_index_map = [element for element in config.rank_index_map for _ in range(repeat_cp_size)]

        config.rank_index_map = repeated_rank_index_map

        print(f'after repeat repeated_rank_index_map: {config.rank_index_map}')

    # Ranks beyond the map length wrap around (modulo indexing).
    num_total_indices = len(config.rank_index_map)

    dataset_index = config.rank_index_map[global_rank % num_total_indices]

    # When the world size is known, split this rank's dataset among every
    # rank that was assigned the same dataset index.
    num_partitions = 1

    partition_id = 0

    if num_processes > 0:

        # Tile the map to cover all ranks, then truncate to exactly
        # num_processes entries: entry r is the dataset index of rank r.
        rank_to_dataset_index_map = list(config.rank_index_map) * num_processes

        rank_to_dataset_index_map = rank_to_dataset_index_map[:num_processes]

        # How many ranks share this dataset ...
        num_partitions = rank_to_dataset_index_map.count(dataset_index)

        # ... and this rank's position among them (count of lower-ranked
        # processes using the same dataset index).
        partition_id = rank_to_dataset_index_map[:global_rank].count(dataset_index)

        print(f'rank_to_dataset_index_map: {rank_to_dataset_index_map}')

        print(f'dataset_index: {dataset_index} partition_id: {partition_id} num_partitions: {num_partitions} ')

    # Rescale loss weights so they average to 1.0 across the active rank
    # slots (weights are summed per map entry, so datasets used by more
    # ranks contribute proportionally more).
    sum_loss_weight = 0.0

    for i in range(num_total_indices):

        dataset_setting = config.dataset_setting[config.rank_index_map[i]]

        sum_loss_weight += dataset_setting.get("loss_weight", 1.0)

    loss_weight_scale = float(num_total_indices) / sum_loss_weight

    dataset_setting = config.dataset_setting[dataset_index]

    loss_weight = dataset_setting.get("loss_weight", 1.0) * loss_weight_scale

    print(f'global_rank: {global_rank} -- dataset_index: {dataset_index} - loss_weight_scale: {loss_weight_scale} - loss weight: {loss_weight} - dataset_setting: {dataset_setting}')

    # Resolve the prompt helpers named in the config by string
    # (module name + function names).
    utils_prompt_module = importlib.import_module(dataset_setting.get_prompt_module)

    get_prompt_func = getattr(utils_prompt_module, dataset_setting.get_prompt_func)

    get_prompt_frame_spans_func = None

    # Optional helper; hasattr on an OmegaConf node checks for the key.
    if hasattr(dataset_setting, "get_prompt_frame_spans_func"):

        get_prompt_frame_spans_func = getattr(utils_prompt_module, dataset_setting.get_prompt_frame_spans_func)

    dataset_kwargs = dataset_setting.get("dataset_kwargs", dict())

    # bucket_configs is mandatory. NOTE(review): hasattr works for OmegaConf
    # DictConfig keys but NOT for the plain-dict default above, so this
    # assert also fails when "dataset_kwargs" is absent — confirm intended.
    assert hasattr(dataset_kwargs, "bucket_configs")

    bucket_configs = dataset_kwargs.get("bucket_configs", dict())

    # Instantiate the dataset class named by "dataset_target", handing it
    # the resolved prompt helpers and this rank's partition of the data.
    dataset = get_obj_from_str(dataset_setting.dataset_target)(

        get_prompt_func=get_prompt_func,

        get_prompt_frame_spans_func=get_prompt_frame_spans_func,

        partition_id=partition_id,

        num_partitions=num_partitions,

        **dataset_kwargs

    )

    dataloader_kwargs = dataset_setting.get("dataloader_kwargs", dict())

    # shuffle/pin_memory/drop_last are fixed here, so dataloader_kwargs must
    # not contain those keys (duplicate-kwarg TypeError otherwise).
    dataloader = torch.utils.data.DataLoader(

        dataset,

        **dataloader_kwargs,

        shuffle=False,

        pin_memory=True,

        drop_last=True,

        # Use the dataset's own collate_fn when it defines one.
        collate_fn = dataset.collate_fn if hasattr(dataset,"collate_fn") else None,

    )

    return dataloader, loss_weight, bucket_configs
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':

    # Smoke-test the rank-aware dataloader setup against an example config.
    example_config_path = "configs/train_t2v_opensora_v2_ms_long32_hq400.yaml"

    config = OmegaConf.load(example_config_path)

    # prepare_dataloader_for_rank returns (dataloader, loss_weight,
    # bucket_configs); the original assigned the whole 3-tuple to
    # `dataloader` and then iterated the tuple instead of the DataLoader.
    dataloader, loss_weight, bucket_configs = prepare_dataloader_for_rank(
        config.video_training_data_config, global_rank=7, num_processes=28
    )

    num_train_steps = 1000

    progress_bar = tqdm(range(0, num_train_steps))

    for step, batch in enumerate(dataloader):

        progress_bar.update(1)

        # step is 0-based: stop after exactly num_train_steps batches
        # (the original `step >= num_train_steps` ran one extra iteration).
        if step + 1 >= num_train_steps:

            break
|
|
|