| from copy import deepcopy |
|
|
| import colossalai |
| import torch |
| import torch.distributed as dist |
| import wandb |
| from colossalai.booster import Booster |
| from colossalai.booster.plugin import LowLevelZeroPlugin |
| from colossalai.cluster import DistCoordinator |
| from colossalai.nn.optimizer import HybridAdam |
| from colossalai.utils import get_current_device |
| from tqdm import tqdm |
|
|
| from opensora.acceleration.checkpoint import set_grad_checkpoint |
| from opensora.acceleration.parallel_states import ( |
| get_data_parallel_group, |
| set_data_parallel_group, |
| set_sequence_parallel_group, |
| ) |
| from opensora.acceleration.plugin import ZeroSeqParallelPlugin |
| from opensora.datasets import DatasetFromCSV, get_transforms_image, get_transforms_video, prepare_dataloader |
| from opensora.registry import MODELS, SCHEDULERS, build_module |
| from opensora.utils.ckpt_utils import create_logger, load, model_sharding, record_model_param_shape, save |
| from opensora.utils.config_utils import ( |
| create_experiment_workspace, |
| create_tensorboard_writer, |
| parse_configs, |
| save_training_config, |
| ) |
| from opensora.utils.misc import all_reduce_mean, format_numel_str, get_model_numel, requires_grad, to_torch_dtype |
| from opensora.utils.train_utils import update_ema |
|
|
|
|
def main():
    """Run distributed diffusion-model training driven by the parsed config.

    Pipeline: parse the training config and create the experiment workspace;
    launch ColossalAI and select a ZeRO plugin (``cfg.plugin``); build the
    CSV-backed dataset and its distributed dataloader; build the VAE, text
    encoder, diffusion model, an fp32 EMA copy, the diffusion scheduler and a
    HybridAdam optimizer; boost them with the chosen plugin; optionally resume
    from ``cfg.load``; then train for ``cfg.epochs`` epochs, logging every
    ``cfg.log_every`` steps and checkpointing every ``cfg.ckpt_every`` steps.

    Raises:
        RuntimeError: if no CUDA device is available.
        ValueError: if ``cfg.dtype`` is not ``"fp16"``/``"bf16"`` or
            ``cfg.plugin`` is not a supported plugin name.
    """
    # ======================================================
    # 1. args & cfg
    # ======================================================
    cfg = parse_configs(training=True)
    print(cfg)
    exp_name, exp_dir = create_experiment_workspace(cfg)
    save_training_config(cfg._cfg_dict, exp_dir)

    # ======================================================
    # 2. runtime variables & colossalai launch
    # ======================================================
    # Explicit raises instead of `assert`, so the checks survive `python -O`.
    if not torch.cuda.is_available():
        raise RuntimeError("Training currently requires at least one GPU.")
    if cfg.dtype not in ("fp16", "bf16"):
        raise ValueError(f"Unknown mixed precision {cfg.dtype}")

    # 2.1. colossalai init distributed training
    colossalai.launch_from_torch({})
    coordinator = DistCoordinator()
    is_master = coordinator.is_master()
    device = get_current_device()
    dtype = to_torch_dtype(cfg.dtype)

    # 2.2. logger, tensorboard & wandb.
    # Only the master rank writes experiment logs; `writer` stays None on the
    # other ranks and is only touched inside the master-only logging branch.
    writer = None
    if is_master:
        logger = create_logger(exp_dir)
        logger.info(f"Experiment directory created at {exp_dir}")
        writer = create_tensorboard_writer(exp_dir)
        if cfg.wandb:
            wandb.init(project="minisora", name=exp_name, config=cfg._cfg_dict)
    else:
        logger = create_logger(None)

    # 2.3. initialize ColossalAI booster
    if cfg.plugin == "zero2":
        plugin = LowLevelZeroPlugin(
            stage=2,
            precision=cfg.dtype,
            initial_scale=2**16,
            max_norm=cfg.grad_clip,
        )
        # Pure data parallelism: every rank belongs to the single DP group.
        set_data_parallel_group(dist.group.WORLD)
    elif cfg.plugin == "zero2-seq":
        plugin = ZeroSeqParallelPlugin(
            sp_size=cfg.sp_size,
            stage=2,
            precision=cfg.dtype,
            initial_scale=2**16,
            max_norm=cfg.grad_clip,
        )
        # The plugin splits ranks into sequence-parallel and data-parallel groups.
        set_sequence_parallel_group(plugin.sp_group)
        set_data_parallel_group(plugin.dp_group)
    else:
        raise ValueError(f"Unknown plugin {cfg.plugin}")
    booster = Booster(plugin=plugin)

    # ======================================================
    # 3. build dataset and dataloader
    # ======================================================
    if cfg.use_image_transform:
        transform = get_transforms_image(cfg.image_size[0])
    else:
        transform = get_transforms_video(cfg.image_size[0])
    dataset = DatasetFromCSV(
        cfg.data_path,
        transform=transform,
        num_frames=cfg.num_frames,
        frame_interval=cfg.frame_interval,
        root=cfg.root,
    )

    # The dataloader is sharded over the data-parallel group, so ranks inside
    # one sequence-parallel group see the same samples.
    dataloader = prepare_dataloader(
        dataset,
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        process_group=get_data_parallel_group(),
    )
    logger.info(f"Dataset contains {len(dataset):,} videos ({cfg.data_path})")

    # Ranks in the same sequence-parallel group share a batch, hence `// sp_size`.
    # NOTE(review): assumes cfg.sp_size is defined (presumably 1) for "zero2" too.
    total_batch_size = cfg.batch_size * dist.get_world_size() // cfg.sp_size
    logger.info(f"Total batch size: {total_batch_size}")

    # ======================================================
    # 4. build model
    # ======================================================
    # 4.1. build models
    input_size = (cfg.num_frames, *cfg.image_size)
    vae = build_module(cfg.vae, MODELS)
    latent_size = vae.get_latent_size(input_size)
    text_encoder = build_module(cfg.text_encoder, MODELS, device=device)
    model = build_module(
        cfg.model,
        MODELS,
        input_size=latent_size,
        in_channels=vae.out_channels,
        caption_channels=text_encoder.output_dim,
        model_max_length=text_encoder.model_max_length,
        dtype=dtype,
    )
    model_numel, model_numel_trainable = get_model_numel(model)
    logger.info(
        f"Trainable model params: {format_numel_str(model_numel_trainable)}, Total model params: {format_numel_str(model_numel)}"
    )

    # 4.2. create ema: copied before the model is cast, so the EMA stays fp32.
    ema = deepcopy(model).to(torch.float32).to(device)
    requires_grad(ema, False)
    # Shapes are recorded before `model_sharding(ema)` below; save() receives them.
    ema_shape_dict = record_model_param_shape(ema)

    # 4.3. move to device
    vae = vae.to(device, dtype)
    model = model.to(device, dtype)

    # 4.4. build scheduler
    scheduler = build_module(cfg.scheduler, SCHEDULERS)

    # 4.5. setup optimizer (only parameters that require grad are optimized)
    optimizer = HybridAdam(
        filter(lambda p: p.requires_grad, model.parameters()), lr=cfg.lr, weight_decay=0, adamw_mode=True
    )
    lr_scheduler = None

    # 4.6. prepare for training
    if cfg.grad_checkpoint:
        set_grad_checkpoint(model)
    model.train()
    # decay=0 is presumably an exact copy of the current weights into the EMA.
    update_ema(ema, model, decay=0, sharded=False)
    ema.eval()

    # =======================================================
    # 5. boost model for distributed training with colossalai
    # =======================================================
    # Boost under the training dtype, then restore the global default to fp32.
    torch.set_default_dtype(dtype)
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(
        model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, dataloader=dataloader
    )
    torch.set_default_dtype(torch.float)
    num_steps_per_epoch = len(dataloader)
    logger.info("Boost model for distributed training")

    # =======================================================
    # 6. training loop
    # =======================================================
    start_epoch = start_step = log_step = sampler_start_idx = 0
    running_loss = 0.0

    # 6.1. resume training
    if cfg.load is not None:
        logger.info("Loading checkpoint")
        start_epoch, start_step, sampler_start_idx = load(booster, model, ema, optimizer, lr_scheduler, cfg.load)
        logger.info(f"Loaded checkpoint {cfg.load} at epoch {start_epoch} step {start_step}")
    logger.info(f"Training for {cfg.epochs} epochs with {num_steps_per_epoch} steps per epoch")

    dataloader.sampler.set_start_index(sampler_start_idx)
    model_sharding(ema)

    # 6.2. training loop
    for epoch in range(start_epoch, cfg.epochs):
        dataloader.sampler.set_epoch(epoch)
        dataloader_iter = iter(dataloader)
        logger.info(f"Beginning epoch {epoch}...")

        with tqdm(
            range(start_step, num_steps_per_epoch),
            desc=f"Epoch {epoch}",
            disable=not is_master,
            total=num_steps_per_epoch,
            initial=start_step,
        ) as pbar:
            for step in pbar:
                batch = next(dataloader_iter)
                x = batch["video"].to(device, dtype)
                y = batch["text"]

                with torch.no_grad():
                    # Prepare visual inputs: encode pixels with the frozen VAE.
                    x = vae.encode(x)
                    # Prepare text inputs: conditioning forwarded to the scheduler/model.
                    model_args = text_encoder.encode(y)

                # Diffusion loss at one uniformly sampled timestep per sample.
                t = torch.randint(0, scheduler.num_timesteps, (x.shape[0],), device=device)
                loss_dict = scheduler.training_losses(model, x, t, model_args)

                # Backward & update
                loss = loss_dict["loss"].mean()
                booster.backward(loss=loss, optimizer=optimizer)
                optimizer.step()
                optimizer.zero_grad()

                # Update EMA from the unwrapped (boosted) model.
                update_ema(ema, model.module, optimizer=optimizer)

                # Average the loss across ranks; kept on every rank because it is
                # presumably a collective operation.
                all_reduce_mean(loss)
                running_loss += loss.item()
                global_step = epoch * num_steps_per_epoch + step
                log_step += 1

                # Log to tensorboard / wandb (master only); reset the running
                # average on every rank so the accumulators stay consistent.
                if (global_step + 1) % cfg.log_every == 0:
                    if is_master:
                        avg_loss = running_loss / log_step
                        pbar.set_postfix({"loss": avg_loss, "step": step, "global_step": global_step})
                        writer.add_scalar("loss", loss.item(), global_step)
                        if cfg.wandb:
                            wandb.log(
                                {
                                    "iter": global_step,
                                    # global_step is 0-indexed: global_step + 1 batches are done.
                                    "num_samples": (global_step + 1) * total_batch_size,
                                    "epoch": epoch,
                                    "loss": loss.item(),
                                    "avg_loss": avg_loss,
                                },
                                step=global_step,
                            )
                    running_loss = 0.0
                    log_step = 0

                # Save checkpoint (called on every rank; the booster coordinates it).
                if cfg.ckpt_every > 0 and (global_step + 1) % cfg.ckpt_every == 0:
                    save(
                        booster,
                        model,
                        ema,
                        optimizer,
                        lr_scheduler,
                        epoch,
                        step + 1,
                        global_step + 1,
                        cfg.batch_size,
                        coordinator,
                        exp_dir,
                        ema_shape_dict,
                    )
                    logger.info(
                        f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {exp_dir}"
                    )

        # Only the resumed epoch starts mid-way; every later epoch starts at 0.
        dataloader.sampler.set_start_index(0)
        start_step = 0


if __name__ == "__main__":
    main()
|
|