import argparse
import logging
import math
import os
import random
import shutil
import signal
import sys
import time
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Dict, Optional, Tuple, List, Union

import accelerate
import diffusers
import einops
import imageio
import numpy as np
import PIL
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import (
    Accelerator,
    AutocastKwargs,
    DeepSpeedPlugin,
    DistributedDataParallelKwargs,
    FullyShardedDataParallelPlugin,
)
from accelerate.logging import get_logger
from accelerate.state import AcceleratorState
from accelerate.utils import ProjectConfiguration, set_seed, DistributedType
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.torch_utils import is_compiled_module
from einops import rearrange, repeat
from fused_ssim import fused_ssim
from huggingface_hub import create_repo, upload_folder
from kiui.lpips import LPIPS
from omegaconf import OmegaConf, ListConfig
from packaging import version
from peft import LoraConfig, get_peft_model_state_dict
from PIL import Image, ImageDraw
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    FullOptimStateDictConfig,
    FullStateDictConfig,
)
from torchvision import transforms
from tqdm.auto import tqdm
from transformers.utils import ContextManagers

from src.models.data import get_multi_dataloader
from src.models.recon.model_latent_recon import LatentRecon
from src.models.utils.cosmos_1_tokenizer import load_cosmos_1_tokenizer
from src.models.utils.loss import compute_loss
from src.models.utils.misc import dtype_map
from src.models.utils.model import (
    encode_latent_time_vae,
    encode_multi_view_video,
    encode_plucker_vae,
    encode_video,
    load_vae,
    repeat_time_spatially,
)
from src.models.utils.render import get_plucker_embedding_and_rays
from src.models.utils.train import get_most_recent_checkpoint
from src.utils.random_state_utils import save_random_state

if is_wandb_available():
    import wandb


# Will raise an error if the minimal version of diffusers is not installed.
check_min_version("0.30.3")

logger = get_logger(__name__, log_level="INFO")


def prepare_config(config: Dict):
    # The LOCAL_RANK environment variable set by the launcher takes precedence
    # over the value in the config.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != config.local_rank:
        config.local_rank = env_local_rank

    # Fill in model-pipeline defaults.
    config.model_pipeline = config.get('model_pipeline', {})
    if config.model_pipeline.get('vae_path', None) is None:
        config.model_pipeline['vae_path'] = config.pretrained_model_name_or_path
    config.model_pipeline['use_lora'] = config.model_pipeline.get('use_lora', False)


def load_model_weights(path_ckpt, transformer):
    # Load only the model weights from a DeepSpeed checkpoint directory.
    path_ckpt_model = os.path.join(path_ckpt, 'pytorch_model', 'mp_rank_00_model_states.pt')
    model_state = torch.load(path_ckpt_model, map_location="cpu")
    model_state = {f'module.{k}': v for k, v in model_state['module'].items()}
    transformer.load_state_dict(model_state, strict=False)
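

# resume_from_ckpt resolves which checkpoint to resume from, in order of precedence:
# an explicit name under `resume_from_checkpoint_dir`, an explicit name under
# `output_dir`, or "latest" (the most recent checkpoint in `output_dir`). It returns
# the step counters and whether the full accelerator state could be restored.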
def resume_from_ckpt(config, accelerator, transformer):
    global_step = 0
    first_epoch = 0
    loaded_accelerator = False

    if config.resume_from_checkpoint:
        is_latest_resume = False
        if config.resume_from_checkpoint_dir is not None:
            path = os.path.basename(config.resume_from_checkpoint)
            path_ckpt = os.path.join(config.resume_from_checkpoint_dir, config.resume_from_checkpoint)
        else:
            if config.resume_from_checkpoint != "latest":
                path = os.path.basename(config.resume_from_checkpoint)
            else:
                # Find the most recent checkpoint in the output directory.
                path = get_most_recent_checkpoint(config.output_dir)
                if path is None:
                    logger.warning("No latest resume checkpoint found, assuming this is our first training session!")
                else:
                    is_latest_resume = True
            if path is not None:
                path_ckpt = os.path.join(config.output_dir, path)
            else:
                path_ckpt = None
        if path_ckpt is None:
            accelerator.print(
                f"Checkpoint '{config.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            config.resume_from_checkpoint = None
            initial_global_step = 0
        else:
            accelerator.print(f"Resuming from checkpoint {path_ckpt}")
            try:
                accelerator.load_state(path_ckpt)
                loaded_accelerator = True
            except Exception as e:
                # A partially written checkpoint may still contain usable model weights.
                print("Failed to load checkpoint: trying to load only the model weights")
                try:
                    load_model_weights(path_ckpt, transformer)
                    print("Loaded only model weights")
                except Exception:
                    logger.warning(f"Failed to load checkpoint: {e}")
                    if is_latest_resume:
                        logger.warning("Removing the broken checkpoint and exiting.")
                        if accelerator.is_main_process:
                            # Keep "bkup" checkpoints around for debugging broken states.
                            if path.endswith("bkup"):
                                logger.warning("Debug: NOT removing the broken checkpoint.")
                            else:
                                shutil.rmtree(path_ckpt)
                    exit(1)

            global_step = int(path.split("-")[1])
            initial_global_step = global_step
            first_epoch = 0
    else:
        initial_global_step = 0
    return initial_global_step, global_step, first_epoch, loaded_accelerator
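

# A checkpoint directory produced by `accelerator.save_state` is laid out roughly as
# follows (a sketch; exact contents depend on the distributed backend in use):
#   output_dir/
#     config.yaml
#     checkpoint-<global_step>/
#       pytorch_model/mp_rank_00_model_states.pt   # DeepSpeed model weights
#       ...optimizer, scheduler and RNG state files...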


def main(
    config: Dict,
    wandb_run_name,
    wandb_group_name,
    app_start_time,
):
    prepare_config(config)
    logging_dir = os.path.join(config.output_dir, config.logging_dir)

    accelerator_project_config = ProjectConfiguration(project_dir=config.output_dir, logging_dir=logging_dir)
    # Unused parameters only occur when gradient accumulation is combined with a
    # restricted set of trainable modules.
    find_unused_parameters = (
        (config.gradient_accumulation_steps > 1) and
        (config.model_pipeline.get('unet_trainable_modules', None) is not None)
    )
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=find_unused_parameters)
    autocast_kwargs = AutocastKwargs(cache_enabled=config.autocast_cache_enabled)

    # At most one of `use_fsdp` / `use_deepspeed` is expected to be enabled.
    if config.use_fsdp:
        fsdp_plugin = FullyShardedDataParallelPlugin(
            state_dict_config=FullStateDictConfig(offload_to_cpu=False, rank0_only=False),
            optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=False, rank0_only=False),
            use_orig_params=True,
        )
        fsdp_plugin.use_orig_params = True
        assert not config.use_ema, "FSDP does not support EMAModel yet, please consider DeepSpeed"
    else:
        fsdp_plugin = None

    if config.use_deepspeed:
        deepspeed_plugin = DeepSpeedPlugin(
            gradient_accumulation_steps=config.gradient_accumulation_steps,
            zero_stage=2,
            gradient_clipping=config.max_grad_norm,
        )
        deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = config.batch_size
        if config.deepspeed_type is None:
            config.deepspeed_type = config.mixed_precision
        if config.deepspeed_type == 'fp16':
            deepspeed_plugin.deepspeed_config['fp16'] = {
                "enabled": 'auto',
                "auto_cast": True,
                "initial_scale_power": 16,
            }
        elif config.deepspeed_type == 'bf16':
            deepspeed_plugin.deepspeed_config['bf16'] = {
                "enabled": True,
            }
    else:
        deepspeed_plugin = None

    accelerator = Accelerator(
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        mixed_precision=config.mixed_precision,
        log_with=config.log_with,
        project_config=accelerator_project_config,
        fsdp_plugin=fsdp_plugin,
        deepspeed_plugin=deepspeed_plugin,
        kwargs_handlers=[ddp_kwargs, autocast_kwargs],
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    if config.seed is not None:
        set_seed(config.seed)
    else:
        print("Not setting a seed")

    # Handle the output directory creation and save the resolved config.
    if accelerator.is_main_process:
        if config.output_dir is not None:
            os.makedirs(config.output_dir, exist_ok=True)
            OmegaConf.save(config, os.path.join(config.output_dir, "config.yaml"))

    def deepspeed_zero_init_disabled_context_manager():
        """
        Returns either a context list that includes one that will disable zero.Init or an empty context list.
        """
        deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
        if deepspeed_plugin is None:
            return []

        return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
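
    # The VAE stays frozen for the whole run: it only encodes input views (and,
    # optionally, time/Plücker maps) into the latents consumed by the transformer.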
    vae = None
    with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
        vae = load_vae(config.vae_backbone, config.vae_path)

    # Convert OmegaConf lists to tuples so the config can be made structured.
    for k, v in config.items():
        if isinstance(v, (list, ListConfig)):
            config[k] = tuple(v)
    config = OmegaConf.structured(config)

    # For mixed precision training, frozen weights are cast to half precision below.
    weight_dtype = torch.float16
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
        config.mixed_precision = accelerator.mixed_precision
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
        config.mixed_precision = accelerator.mixed_precision

    transformer = LatentRecon(config)
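
    # Optional LPIPS perceptual loss; the VGG backbone is frozen and evaluated at
    # `lpips_img_size` inside `compute_loss`.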
    if config.lambda_lpips > 0:
        lpips_img_size = config.img_size if not isinstance(config.img_size, int) else [config.img_size, config.img_size]
        lpips_loss_module = LPIPS(net='vgg')
        lpips_loss_module.requires_grad_(False)
        lpips_loss_module = lpips_loss_module.to(accelerator.device)
    else:
        lpips_loss_module = None
        lpips_img_size = None

    if config.resume_pretrained_model_ckpt:
        # Warm-start the transformer from a pretrained checkpoint.
        logger.info(f"Loading pretrained ckpt from: {config.resume_pretrained_model_ckpt}")
        data = torch.load(config.resume_pretrained_model_ckpt)
        transformer.load_state_dict(data["module"])

    # Freeze everything first; trainable parameters are re-enabled below.
    transformer.train()
    transformer.requires_grad_(False)
    if config.set_transformer_dtype:
        for module in transformer.modules():
            module.to(accelerator.device, dtype=weight_dtype)
    modules_dtype = [vae]
    for module in modules_dtype:
        if module is not None:
            module.requires_grad_(False)
            module.to(accelerator.device, dtype=weight_dtype)
    if config.compile_frozen_modules:
        vae.encode = torch.compile(vae.encode)

    lora_params = []
    if config.model_pipeline['use_lora']:
        transformer_lora_config = LoraConfig(
            r=config.model_pipeline.get('lora_rank', 64),
            lora_alpha=config.model_pipeline.get('lora_alpha', 64),
            init_lora_weights=True,
            target_modules=["to_k", "to_q", "to_v", "to_out.0"],
        )
        transformer.add_adapter(transformer_lora_config)
        # add_adapter leaves only the LoRA weights trainable; record their names so
        # they can be merged into the optimizer parameter list below.
        lora_params = [name for name, p in transformer.named_parameters() if p.requires_grad]
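
    # Helper to strip the accelerator (and `torch.compile`) wrappers from a model,
    # e.g. before saving or introspecting it.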
    def unwrap_model(model):
        model = accelerator.unwrap_model(model)
        model = model._orig_mod if is_compiled_module(model) else model
        return model
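
    # Custom hooks so that `accelerator.save_state` / `load_state` use the diffusers
    # `save_pretrained` format for the model. They are not registered for FSDP or
    # DeepSpeed, which serialize (possibly sharded) state themselves.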
    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):

        def save_model_hook(models, weights, output_dir):
            if accelerator.is_main_process:
                if config.use_ema:
                    ema_transformer.save_pretrained(os.path.join(output_dir, "transformer_ema"))

                for i, model in enumerate(models):
                    model.save_pretrained(os.path.join(output_dir, "transformer"))

                    if config.model_pipeline['use_lora']:
                        transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                        model.save_lora_weights(os.path.join(output_dir, "lora"), transformer_lora_layers=transformer_lora_layers_to_save)

                    # Pop the weight so that the corresponding model is not saved again.
                    weights.pop()

        def load_model_hook(models, input_dir):
            for _ in range(len(models)):
                # Pop models so that they are not loaded again.
                model = models.pop()

                # Load the saved diffusers-format weights back into the live model.
                load_model = LatentRecon.from_pretrained(input_dir, subfolder="transformer")
                model.register_to_config(**load_model.config)

                model.load_state_dict(load_model.state_dict())
                del load_model

        if accelerator.distributed_type not in [DistributedType.FSDP, DistributedType.DEEPSPEED]:
            accelerator.register_save_state_pre_hook(save_model_hook)
            accelerator.register_load_state_pre_hook(load_model_hook)
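
    # TF32 matmuls trade a little precision for a large speedup on Ampere+ GPUs.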
    if config.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if config.scale_lr:
        config.learning_rate = (
            config.learning_rate * config.gradient_accumulation_steps * config.batch_size * accelerator.num_processes
        )

    if config.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )

        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    parameters_list = []
    param_names = []

    # Restrict training to the configured transformer modules, if any; LoRA
    # parameters (if enabled above) are appended afterwards.
    transformer_trainable_modules = config.model_pipeline.get('transformer_trainable_modules', None)
    if transformer_trainable_modules is not None:
        for name, param in transformer.named_parameters():
            for module in transformer_trainable_modules:
                if module in name:
                    parameters_list.append(param)
                    param_names.append(name)
                    param.requires_grad = True
                    break

        for name, param in transformer.named_parameters():
            if name in lora_params and name not in param_names:
                parameters_list.append(param)
                param_names.append(name)
                param.requires_grad = True
    else:
        transformer.requires_grad_(True)
        parameters_list = transformer.parameters()
        param_names = [name for name, param in transformer.named_parameters()]

    # FSDP requires preparing (wrapping) the model before the optimizer is created.
    if accelerator.distributed_type == DistributedType.FSDP:
        transformer = accelerator.prepare(transformer)
    logger.info("***** Parameters list *****")
    logger.info(f"{param_names}")
    optimizer = optimizer_cls(
        parameters_list,
        lr=config.learning_rate,
        betas=(config.adam_beta1, config.adam_beta2),
        weight_decay=config.adam_weight_decay,
        eps=config.adam_epsilon,
    )
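
    # Note: after `accelerator.prepare`, the LR scheduler is stepped on every process,
    # which is why the warmup/total step counts below are multiplied by `num_processes`.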
    global_batch_size = config.batch_size * accelerator.num_processes

    overrode_max_train_steps = False

    lr_scheduler = get_scheduler(
        config.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=config.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=config.max_train_steps * accelerator.num_processes,
    )

    if accelerator.distributed_type == DistributedType.FSDP:
        optimizer, lr_scheduler = accelerator.prepare(optimizer, lr_scheduler)
    else:
        transformer, optimizer, lr_scheduler = accelerator.prepare(transformer, optimizer, lr_scheduler)

    # Initialize trackers on the main process; tracker configs only accept scalar
    # values, so non-scalar entries are dropped.
    if accelerator.is_main_process:
        tracker_config = dict(vars(config))
        pop_keys = []
        for k, v in tracker_config.items():
            if v is not None and not isinstance(v, (int, float, str, bool, torch.Tensor)):
                pop_keys.append(k)
        for k in pop_keys:
            tracker_config.pop(k)

        init_kwargs = {
            "wandb": {
                "name": wandb_run_name,
                "dir": config.output_dir,
                "group": wandb_group_name,
                "tags": ["cosmos_3dgs"],
                "resume": "auto",
            },
        }

        accelerator.init_trackers(config.experiment_name, config=tracker_config, init_kwargs=init_kwargs)

    global_step = 0
    first_epoch = 0

    initial_global_step, global_step, first_epoch, loaded_accelerator = resume_from_ckpt(config, accelerator, transformer)

    if config.lr_overwrite:
        print(f"Overwriting learning rate to {config.learning_rate} and resetting the LR scheduler")
        for param_group in optimizer.param_groups:
            param_group['lr'] = config.learning_rate
        lr_scheduler = get_scheduler(
            config.lr_scheduler,
            optimizer=optimizer,
            num_warmup_steps=config.lr_warmup_steps * accelerator.num_processes,
            num_training_steps=config.max_train_steps * accelerator.num_processes,
        )

    # Re-seed with a device-specific seed when no RNG state was restored from a
    # checkpoint, so that each rank draws a different data/augmentation stream.
    if (initial_global_step == 0 or not loaded_accelerator) and config.seed is not None:
        set_seed(config.seed, device_specific=True)
        print(f"Set seed to {config.seed}")

    wds_loader = True
    train_dataloader, test_dataloader = get_multi_dataloader(config, accelerator)
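
    # Batches from the WebDataset-style loader arrive on the CPU; `train_step` moves
    # them to the device and casts floating tensors (camera parameters excepted) to
    # `weight_dtype`.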
    def train_step(batch, train_loss, num_input_multi_views):
        threedgs_kwargs = config

        if wds_loader:
            # Move tensors to the device; camera parameters are kept in full precision.
            batch_keys = list(batch.keys())
            for k in batch_keys:
                if isinstance(batch[k], torch.Tensor):
                    batch[k] = batch[k].to(accelerator.device, non_blocking=True)

                    if k not in ['intrinsics_input', 'c2ws_input', 'cam_view', 'intrinsics']:
                        batch[k] = batch[k].to(weight_dtype)

        gt_images = batch['images_output']
        gt_depths = batch.get('depths_output', None)

        # The number of input views may vary between batches but must be constant
        # within a batch.
        if 'num_input_multi_views' in batch:
            assert (batch['num_input_multi_views'][0] == batch['num_input_multi_views']).all(), "Not supporting multi batch size for variable multi-view"
            num_input_multi_views = int(batch['num_input_multi_views'][0].item())
            batch['num_input_multi_views'] = num_input_multi_views

        # Use precomputed latents when available; otherwise encode the input views
        # with the (frozen) VAE.
        if 'rgb_latents' in batch:
            model_input = batch['rgb_latents'].to(weight_dtype)
        else:
            video = batch['images_input_vae'].to(weight_dtype)
            if threedgs_kwargs.use_rgb_decoder:
                model_input = video
            else:
                model_input = encode_multi_view_video(vae, video, num_input_multi_views, config.vae_backbone)
        batch['images_input_embed'] = model_input

        # Compute Plücker ray embeddings on the GPU.
        if threedgs_kwargs.get('compute_plucker_cuda', True):
            batch['plucker_embedding'], batch['rays_os'], batch['rays_ds'] = get_plucker_embedding_and_rays(
                batch['intrinsics_input'],
                batch['c2ws_input'],
                threedgs_kwargs.img_size,
                threedgs_kwargs.patch_size_out_factor,
                batch['flip_flag'],
                get_batch_index=False,
                dtype=dtype_map[threedgs_kwargs.compute_plucker_dtype],
                out_dtype=weight_dtype,
            )

        # Optionally encode time and Plücker embeddings through the VAE as well.
        if threedgs_kwargs.get('use_time_embedding', False) and threedgs_kwargs.get('time_embedding_vae', False):
            batch = encode_latent_time_vae(batch, lambda x: encode_video(vae, x, config.vae_backbone), threedgs_kwargs.img_size)
        if threedgs_kwargs.get('plucker_embedding_vae', False):
            batch = encode_plucker_vae(batch, lambda x: encode_multi_view_video(vae, x, num_input_multi_views, config.vae_backbone))

        model_output = transformer(batch)

        pred_images = model_output['images_pred'].to(gt_images.dtype)
        pred_depths = model_output['depths_pred'].to(gt_images.dtype)
        pred_opacity = model_output['opacity_pred']
        train_loss, loss = compute_loss(accelerator, train_loss, pred_images, gt_images, pred_depths, gt_depths, pred_opacity, config, lpips_loss_module, lpips_img_size)

        accelerator.backward(loss)
        if accelerator.sync_gradients:
            accelerator.clip_grad_norm_(transformer.parameters(), config.max_grad_norm)

        # With fp16 grad scaling, surface inf/NaN gradients before stepping so that
        # skipped steps can be logged below.
        if optimizer.scaler is not None:
            optimizer.scaler._check_inf_per_device(optimizer.optimizer)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        return train_loss

    total_batch_size = config.batch_size * accelerator.num_processes * config.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Instantaneous batch size per device = {config.batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {config.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {config.max_train_steps}")
    logger.info(f" Output dir: {config.output_dir}")

    progress_bar = tqdm(
        range(0, config.max_train_steps),
        initial=initial_global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )
    break_loop = False
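
    # Train indefinitely over the streaming dataloader; the loop exits via `break_loop`
    # once `max_train_steps` (or the optional `job_stop_steps`) is reached.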
    while True:
        train_loss = 0.0
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(transformer):
                train_loss = train_step(batch, train_loss, config.num_input_multi_views)

            # Checks if the accelerator has performed an optimization step behind the scenes.
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                accelerator.log({
                    "train_loss": train_loss,
                    "lr": lr_scheduler.get_last_lr()[0],
                }, step=global_step)
                train_loss = 0.0

                if global_step % config.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        # Before saving a new checkpoint, enforce `checkpoints_total_limit`.
                        if config.checkpoints_total_limit is not None:
                            checkpoints = os.listdir(config.output_dir)
                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

                            # Permanent checkpoints (every `permanent_checkpointing_steps`)
                            # are excluded from the removal candidates.
                            checkpoints = [ckpt for ckpt in checkpoints if
                                           (int(ckpt.split("-")[1]) % config.permanent_checkpointing_steps) != 0]
                            print(checkpoints)

                            # We need at most `checkpoints_total_limit - 1` checkpoints
                            # before saving the new one.
                            if len(checkpoints) >= config.checkpoints_total_limit:
                                num_to_remove = len(checkpoints) - config.checkpoints_total_limit + 1
                                removing_checkpoints = checkpoints[0:num_to_remove]

                                logger.info(
                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                                )
                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

                                for removing_checkpoint in removing_checkpoints:
                                    removing_checkpoint = os.path.join(config.output_dir, removing_checkpoint)
                                    shutil.rmtree(removing_checkpoint)

                    save_path = os.path.join(config.output_dir, f"checkpoint-{global_step}")
                    if accelerator.is_main_process or accelerator.distributed_type in [DistributedType.FSDP, DistributedType.DEEPSPEED]:
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

                    # Also persist per-rank RNG state so that resuming reproduces the data stream.
                    if config.save_multi_random_states and not accelerator.is_main_process:
                        save_path = os.path.join(config.output_dir, f"checkpoint-{global_step}")
                        os.makedirs(save_path, exist_ok=True)
                        save_random_state(save_path, accelerator.process_index)

                if config.job_stop_steps is not None and global_step % config.job_stop_steps == 0:
                    logger.info('Reach Job Stop Steps')
                    break_loop = True
                    break

            logs = {"step_loss": train_loss, "lr": lr_scheduler.get_last_lr()[0], "dir": config.output_dir}
            if optimizer.step_was_skipped:
                logs["overflow"] = 1
                logs["scaler"] = optimizer.scaler._scale.item() if optimizer.scaler is not None else 1
                logger.warning(f"Gradient overflow. Skipping step {global_step}, scaler {logs['scaler']}")
            progress_bar.set_postfix(**logs)

            if global_step >= config.max_train_steps:
                logger.info('Reach Max Train Steps')
                break_loop = True
                break

        if break_loop:
            break

    accelerator.wait_for_everyone()
    accelerator.end_training()
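

# Config resolution order (later wins): default schema -> experiment config -> CLI
# dotlist overrides, e.g. (hypothetical keys shown):
#   python <this_script>.py --config configs/training/exp.yaml learning_rate=1e-4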
if __name__ == "__main__":
    app_start_time = time.time_ns() / 1_000_000
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True)
    parser.add_argument('--config_default', type=str, default='configs/training/default.yaml')
    parser.add_argument('--wandb_run_name', type=str, default=None, help="Name of run for wandb")
    parser.add_argument('--wandb_group_name', type=str, default=None)
    args, unknown = parser.parse_known_args()
    schema = OmegaConf.load(args.config_default)
    config = OmegaConf.load(args.config)
    # Add keys that exist in the experiment config but are missing from the default
    # schema, so that the merge below does not drop them.
    missing_keys = set(config.keys()) - set(schema.keys())
    for key in missing_keys:
        OmegaConf.update(schema, key, None, force_add=True)
    config = OmegaConf.merge(schema, config)
    cli = OmegaConf.from_dotlist(unknown)
    config = OmegaConf.merge(config, cli)

    try:
        main(config, args.wandb_run_name, args.wandb_group_name, app_start_time)
    finally:
        if is_wandb_available():
            wandb.finish()