# File size: 31,514 Bytes
# af758d1
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import math
import os
import random
import shutil
from pathlib import Path
import einops
import signal
import sys
import accelerate
import numpy as np
import imageio
import PIL
from PIL import Image, ImageDraw
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.distributed.fsdp.fully_sharded_data_parallel import (
FullOptimStateDictConfig, FullStateDictConfig
)
import transformers
from accelerate import Accelerator, DistributedDataParallelKwargs, AutocastKwargs
from accelerate import FullyShardedDataParallelPlugin, DeepSpeedPlugin
from accelerate.logging import get_logger
from accelerate.state import AcceleratorState
from accelerate.utils import ProjectConfiguration, set_seed, DistributedType
from huggingface_hub import create_repo, upload_folder
from packaging import version
from torchvision import transforms
from tqdm.auto import tqdm
from einops import rearrange, repeat
from transformers.utils import ContextManagers
from typing import Dict, Optional, Tuple, List, Union
from omegaconf import OmegaConf, ListConfig
from dataclasses import dataclass, asdict
import diffusers
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.torch_utils import is_compiled_module
from peft import LoraConfig, get_peft_model_state_dict
from src.utils.random_state_utils import save_random_state
from src.models.recon.model_latent_recon import LatentRecon
from kiui.lpips import LPIPS
from fused_ssim import fused_ssim
from src.models.data import get_multi_dataloader
from src.models.utils.model import encode_latent_time_vae, encode_plucker_vae, repeat_time_spatially
from src.models.utils.cosmos_1_tokenizer import load_cosmos_1_tokenizer
from src.models.utils.render import get_plucker_embedding_and_rays
from src.models.utils.misc import dtype_map
from src.models.utils.model import encode_multi_view_video, load_vae, encode_video
from src.models.utils.loss import compute_loss
from src.models.utils.train import get_most_recent_checkpoint
import time
# wandb is optional: the name `wandb` only exists when the package is installed.
if is_wandb_available():
    import wandb

# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.30.3")

logger = get_logger(__name__, log_level="INFO")
def prepare_config(
    config: Dict
):
    """Normalise ``config`` in place before training starts.

    * If the ``LOCAL_RANK`` environment variable is set and differs from
      ``config.local_rank``, the environment value wins.
    * ``config.model_pipeline`` is guaranteed to be a mapping afterwards
      (a missing or null entry becomes an empty mapping).
    * ``model_pipeline['vae_path']`` falls back to
      ``config.pretrained_model_name_or_path`` when unset.
    * ``model_pipeline['use_lora']`` defaults to ``False``.

    Args:
        config: Mapping-like configuration supporting both attribute access
            and ``.get``. It is mutated; nothing is returned.
    """
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != config.local_rank:
        config.local_rank = env_local_rank

    # Set paths
    # A key that exists with a null value makes `.get(key, {})` return None,
    # so the None case is handled explicitly instead of relying on the default.
    model_pipeline = config.get('model_pipeline', None)
    if model_pipeline is None:
        model_pipeline = {}
    config.model_pipeline = model_pipeline

    # Keep reading through `config.model_pipeline`: the config object may wrap
    # the assigned mapping, so the local variable is not necessarily the stored one.
    if config.model_pipeline.get('vae_path', None) is None:
        config.model_pipeline['vae_path'] = config.pretrained_model_name_or_path
    config.model_pipeline['use_lora'] = config.model_pipeline.get('use_lora', False)
def load_model_weights(path_ckpt, transformer):
    """Load only the model weights stored under ``path_ckpt`` into ``transformer``.

    Reads ``<path_ckpt>/pytorch_model/mp_rank_00_model_states.pt``, takes its
    ``'module'`` entry, prefixes every key with ``module.`` and loads the
    result non-strictly, so partial matches are accepted.

    Args:
        path_ckpt: Checkpoint directory.
        transformer: Module whose ``state_dict`` keys start with ``module.``.

    Returns:
        The result of ``load_state_dict`` (missing / unexpected key lists).

    Raises:
        RuntimeError: If the checkpoint is non-empty but none of its keys
            exist in ``transformer`` — i.e. nothing was actually loaded.
    """
    path_ckpt_model = os.path.join(path_ckpt, 'pytorch_model', 'mp_rank_00_model_states.pt')
    checkpoint = torch.load(path_ckpt_model, map_location="cpu")
    model_state = {f'module.{k}': v for k, v in checkpoint['module'].items()}

    load_result = transformer.load_state_dict(model_state, strict=False)

    # strict=False never raises on key mismatch; without this check a checkpoint
    # that shares no key with the model would be reported as loaded.
    if model_state and len(load_result.unexpected_keys) == len(model_state):
        raise RuntimeError(
            f"None of the {len(model_state)} keys in '{path_ckpt_model}' match the model; no weights were loaded."
        )
    return load_result
def resume_from_ckpt(config, accelerator, transformer):
    """Restore training state from a checkpoint, if one is requested.

    The checkpoint is selected as follows:

    * ``config.resume_from_checkpoint`` falsy: nothing is loaded.
    * ``config.resume_from_checkpoint_dir`` set: the checkpoint is looked up
      under that directory.
    * otherwise it is looked up under ``config.output_dir``; the special value
      ``"latest"`` selects the result of ``get_most_recent_checkpoint``.

    The full accelerator state is loaded first. If that fails, only the model
    weights are loaded via ``load_model_weights``. If that fails as well, the
    process exits with status 1; for a ``"latest"`` resume the main process
    first deletes the broken checkpoint directory (unless its name ends with
    ``bkup``).

    Args:
        config: Configuration object; ``resume_from_checkpoint`` is reset to
            ``None`` when no checkpoint can be located.
        accelerator: Object providing ``print``, ``load_state`` and
            ``is_main_process``.
        transformer: Model passed to ``load_model_weights`` on fallback.

    Returns:
        ``(initial_global_step, global_step, first_epoch, loaded_accelerator)``.
        The step is parsed from the checkpoint name (``<prefix>-<step>...``);
        ``loaded_accelerator`` is True only when the full state was restored.
    """
    global_step = 0
    first_epoch = 0
    loaded_accelerator = False

    # Potentially load in the weights and states from a previous save
    if config.resume_from_checkpoint:
        is_latest_resume = False
        if config.resume_from_checkpoint_dir is not None:
            path = os.path.basename(config.resume_from_checkpoint)
            path_ckpt = os.path.join(config.resume_from_checkpoint_dir, config.resume_from_checkpoint)
        else:
            if config.resume_from_checkpoint != "latest":
                path = os.path.basename(config.resume_from_checkpoint)
            else:
                # Get the most recent checkpoint
                path = get_most_recent_checkpoint(config.output_dir)
                if path is None:
                    # This must be our first time since no checkpoint was found
                    logger.warning(f"No latest resume checkpoint found, assuming this is our first training session!")
                else:
                    is_latest_resume = True
            if path is not None:
                path_ckpt = os.path.join(config.output_dir, path)
            else:
                path_ckpt = None

        if path_ckpt is None:
            accelerator.print(
                f"Checkpoint '{config.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            config.resume_from_checkpoint = None
            initial_global_step = 0
        else:
            accelerator.print(f"Resuming from checkpoint {path_ckpt}")
            try:
                accelerator.load_state(path_ckpt)  # will also resume the random seed states # type: ignore
                loaded_accelerator = True
            except Exception as e:
                # Only load model without optimizer
                print("Failed to load checkpoint: Try to only load model weights")
                try:
                    load_model_weights(path_ckpt, transformer)
                    print("Loaded only model weights")
                except Exception as weights_error:
                    # `except Exception` (not a bare except) so that
                    # KeyboardInterrupt / SystemExit are not swallowed here.
                    logger.warning(f"Failed to load checkpoint: {e}")
                    logger.warning(f"Failed to load model weights: {weights_error}")
                    if is_latest_resume:
                        logger.warning("Remove the broken checkpoint and exit.")
                        if accelerator.is_main_process:
                            # remove the checkpoint if it fails to load
                            if path.endswith("bkup"):
                                logger.warning("Debug NOT removing the broken checkpoint.")
                            else:
                                shutil.rmtree(path_ckpt)  # type: ignore
                    # Nothing could be restored: never continue with a step
                    # counter taken from a checkpoint that was not loaded.
                    sys.exit(1)
            global_step = int(path.split("-")[1])
            initial_global_step = global_step
            first_epoch = 0
    else:
        initial_global_step = 0

    return initial_global_step, global_step, first_epoch, loaded_accelerator
def main(
    config: Dict,
    wandb_run_name,
    wandb_group_name,
    app_start_time,
):
    """Run the training job described by ``config``.

    Builds the accelerator (optionally with an FSDP or DeepSpeed plugin), loads
    the frozen VAE and the ``LatentRecon`` model, creates optimizer, scheduler
    and trackers, optionally resumes from a checkpoint, and then iterates over
    the training dataloader until ``config.max_train_steps`` is reached or the
    step count hits a multiple of ``config.job_stop_steps``.

    Args:
        config: Merged configuration, read both by attribute and by key.
        wandb_run_name: Run name forwarded to the wandb tracker.
        wandb_group_name: Group name forwarded to the wandb tracker.
        app_start_time: Launch timestamp supplied by the caller; not read here.
    """
    prepare_config(config)
    logging_dir = os.path.join(config.output_dir, config.logging_dir)  # type: ignore
    accelerator_project_config = ProjectConfiguration(project_dir=config.output_dir, logging_dir=logging_dir)

    find_unused_parameters = (
        (config.gradient_accumulation_steps > 1) and
        (config.model_pipeline.get('unet_trainable_modules', None) is not None)
    )
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=find_unused_parameters)
    autocast_kwargs = AutocastKwargs(cache_enabled=config.autocast_cache_enabled)

    if config.use_fsdp:
        fsdp_plugin = FullyShardedDataParallelPlugin(
            state_dict_config=FullStateDictConfig(offload_to_cpu=False, rank0_only=False),
            optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=False, rank0_only=False),
            use_orig_params=True,
        )
        # NOTE(review): the flag is set again after construction — presumably the
        # plugin can override the constructor value; confirm before removing.
        fsdp_plugin.use_orig_params = True
        assert not config.use_ema, "FSDP does not support EMAModel yet, please consider DeepSpeed"
    else:
        fsdp_plugin = None

    if config.use_deepspeed:
        deepspeed_plugin = DeepSpeedPlugin(
            gradient_accumulation_steps=config.gradient_accumulation_steps,
            zero_stage=2,
            gradient_clipping=config.max_grad_norm
        )
        deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = config.batch_size
        if config.deepspeed_type is None:
            config.deepspeed_type = config.mixed_precision
        if config.deepspeed_type == 'fp16':
            deepspeed_plugin.deepspeed_config['fp16'] = {
                "enabled": 'auto',
                "auto_cast": True,
                "initial_scale_power": 16,
            }
        elif config.deepspeed_type == 'bf16':
            deepspeed_plugin.deepspeed_config['bf16'] = {
                "enabled": True,
            }
    else:
        deepspeed_plugin = None

    accelerator = Accelerator(
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        mixed_precision=config.mixed_precision,
        log_with=config.log_with,
        project_config=accelerator_project_config,
        fsdp_plugin=fsdp_plugin,
        deepspeed_plugin=deepspeed_plugin,
        kwargs_handlers=[ddp_kwargs, autocast_kwargs]
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if config.seed is not None:
        # Same seed across all processes so that models are initialized the same way
        set_seed(config.seed)
    else:
        print("Not setting a seed")

    # Handle the repository creation
    if accelerator.is_main_process:
        if config.output_dir is not None:
            os.makedirs(config.output_dir, exist_ok=True)
            OmegaConf.save(config, os.path.join(config.output_dir, "config.yaml"))

    def deepspeed_zero_init_disabled_context_manager():
        """
        returns either a context list that includes one that will disable zero.Init or an empty context list
        """
        deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
        if deepspeed_plugin is None:
            return []

        return [deepspeed_plugin.zero3_init_context_manager(enable=False)]

    # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3.
    # For this to work properly all models must be run through `accelerate.prepare`. But accelerate
    # will try to assign the same optimizer with the same weights to all models during
    # `deepspeed.initialize`, which of course doesn't work.
    #
    # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the
    # frozen VAE from being partitioned during `zero.Init`, so it will not be sharded across GPUs.
    vae = None
    with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
        vae = load_vae(config.vae_backbone, config.vae_path)

    for k, v in config.items():
        if isinstance(v, list) or isinstance(v, ListConfig):
            config[k] = tuple(v)
    config = OmegaConf.structured(config)

    # For mixed precision training we cast all non-trainable weights to half-precision
    # as these weights are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float16
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
        config.mixed_precision = accelerator.mixed_precision
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
        config.mixed_precision = accelerator.mixed_precision

    transformer = LatentRecon(
        config,
    )

    # Use lpips loss.
    # `lpips_img_size` is always passed to `compute_loss` by `train_step`, so it must be
    # defined even when the LPIPS term is disabled.
    lpips_img_size = config.img_size if not isinstance(config.img_size, int) else [config.img_size, config.img_size]
    if config.lambda_lpips > 0:
        lpips_loss_module = LPIPS(net='vgg')
        lpips_loss_module.requires_grad_(False)
        lpips_loss_module = lpips_loss_module.to(accelerator.device)
    else:
        lpips_loss_module = None

    if config.resume_pretrained_model_ckpt:
        # TODO don't do this if we are resuming latest
        logger.info(f"Loading pretrain ckpt from: {config.resume_pretrained_model_ckpt}")
        # Load on CPU so that tensors are not materialised on the device they were saved from.
        data = torch.load(config.resume_pretrained_model_ckpt, map_location="cpu")
        transformer.load_state_dict(data["module"])

    # Freeze vae and transformer (trainable parameters are re-enabled below)
    transformer.train()
    transformer.requires_grad_(False)
    if config.set_transformer_dtype:
        for module in transformer.modules():
            module.to(accelerator.device, dtype=weight_dtype)

    modules_dtype = [vae]
    for module in modules_dtype:
        if module is not None:
            module.requires_grad_(False)
            module.to(accelerator.device, dtype=weight_dtype)

    if config.compile_frozen_modules:
        vae.encode = torch.compile(vae.encode)

    # Add lora support
    lora_params = []
    if config.model_pipeline['use_lora']:
        transformer_lora_config = LoraConfig(
            r=config.model_pipeline.get('lora_rank', 64),
            lora_alpha=config.model_pipeline.get('lora_alpha', 64),
            init_lora_weights=True,
            target_modules=["to_k", "to_q", "to_v", "to_out.0"],
        )
        transformer.add_adapter(transformer_lora_config)
        # Everything was frozen above, so whatever requires grad now was added by the adapter.
        lora_params = [name for name, p in transformer.named_parameters() if p.requires_grad]

    # Function for unwrapping if model was compiled with `torch.compile`.
    def unwrap_model(model):
        model = accelerator.unwrap_model(model)
        model = model._orig_mod if is_compiled_module(model) else model
        return model

    # NOTE: currently only save and load the transformer model
    # `accelerate` 0.16.0 will have better support for customized saving
    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
        def save_model_hook(models, weights, output_dir):
            if accelerator.is_main_process:
                if config.use_ema:
                    # NOTE(review): `ema_transformer` is not defined anywhere in the visible
                    # source; this branch raises NameError if `use_ema` is enabled.
                    ema_transformer.save_pretrained(os.path.join(output_dir, "transformer_ema"))

                for model in models:
                    model.save_pretrained(os.path.join(output_dir, "transformer"))
                    if config.model_pipeline['use_lora']:
                        transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                        model.save_lora_weights(os.path.join(output_dir, "lora"), transformer_lora_layers=transformer_lora_layers_to_save)

                    # make sure to pop weight so that corresponding model is not saved again
                    weights.pop()

        def load_model_hook(models, input_dir):
            for _ in range(len(models)):
                # pop models so that they are not loaded again
                model = models.pop()

                # load diffusers style into model; `LatentRecon` is the class instantiated
                # above (the previously referenced class name does not exist in this file)
                load_model = LatentRecon.from_pretrained(input_dir, subfolder="transformer")
                model.register_to_config(**load_model.config)

                model.load_state_dict(load_model.state_dict())
                del load_model

        if accelerator.distributed_type not in [DistributedType.FSDP, DistributedType.DEEPSPEED]:
            accelerator.register_save_state_pre_hook(save_model_hook)
            accelerator.register_load_state_pre_hook(load_model_hook)

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if config.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if config.scale_lr:
        config.learning_rate = (
            config.learning_rate * config.gradient_accumulation_steps * config.batch_size * accelerator.num_processes
        )

    # Initialize the optimizer
    if config.use_8bit_adam:
        try:
            import bitsandbytes as bnb  # type: ignore
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )

        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    parameters_list = []
    param_names = []

    # Select the parameters to train: substring match against the configured module names,
    # plus any LoRA parameter that was not already selected.
    transformer_trainable_modules = config.model_pipeline.get('transformer_trainable_modules', None)  # None means all trainable
    if transformer_trainable_modules is not None:
        for name, param in transformer.named_parameters():
            if any(module in name for module in transformer_trainable_modules):
                parameters_list.append(param)
                param_names.append(name)
                param.requires_grad = True

        already_selected = set(param_names)
        lora_param_names = set(lora_params)
        for name, param in transformer.named_parameters():
            if name in lora_param_names and name not in already_selected:
                parameters_list.append(param)
                param_names.append(name)
                param.requires_grad = True
    else:
        transformer.requires_grad_(True)
        # Deliberately left as a lazy iterator: it is consumed at optimizer creation,
        # i.e. after the optional FSDP wrapping below.
        parameters_list = transformer.parameters()
        param_names = [name for name, param in transformer.named_parameters()]

    # fsdp - prepare model in advance of optimizer creation
    if accelerator.distributed_type == DistributedType.FSDP:
        transformer = accelerator.prepare(transformer)

    logger.info("***** Parameters list *****")
    logger.info(f"{param_names}")

    optimizer = optimizer_cls(
        parameters_list,
        lr=config.learning_rate,
        betas=(config.adam_beta1, config.adam_beta2),
        weight_decay=config.adam_weight_decay,
        eps=config.adam_epsilon,
    )

    # Scheduler and math around the number of training steps.
    lr_scheduler = get_scheduler(
        config.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=config.lr_warmup_steps * accelerator.num_processes,  # NOTE: accelerate iterates num_proc times per step, not a bug
        num_training_steps=config.max_train_steps * accelerator.num_processes,
    )

    # Prepare everything with our `accelerator`.
    if accelerator.distributed_type == DistributedType.FSDP:
        optimizer, lr_scheduler = accelerator.prepare(optimizer, lr_scheduler)
    else:
        transformer, optimizer, lr_scheduler = accelerator.prepare(transformer, optimizer, lr_scheduler)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        # Read the entries through the mapping interface: `vars(config)` on an OmegaConf
        # container exposes its internal attributes, not the configuration values.
        # Only scalar-like values (and None) are forwarded to the trackers.
        tracker_config = {
            k: v
            for k, v in config.items()
            if v is None or isinstance(v, (int, float, str, bool, torch.Tensor))
        }
        init_kwargs = {
            "wandb": {
                "name": wandb_run_name,
                "dir": config.output_dir,
                "group": wandb_group_name,
                "tags": ["cosmos_3dgs"],
                "resume": "auto",
            },
        }
        accelerator.init_trackers(config.experiment_name, config=tracker_config, init_kwargs=init_kwargs)

    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    initial_global_step, global_step, first_epoch, loaded_accelerator = resume_from_ckpt(config, accelerator, transformer)

    # Overwrite learning rate
    if config.lr_overwrite:
        print(f"Set new optimizer with learning rate {config.learning_rate}")
        for param_group in optimizer.param_groups:
            param_group['lr'] = config.learning_rate
        # NOTE(review): this replacement scheduler is not passed through
        # `accelerator.prepare`, unlike the one created above — confirm this is intended.
        lr_scheduler = get_scheduler(
            config.lr_scheduler,
            optimizer=optimizer,
            num_warmup_steps=config.lr_warmup_steps * accelerator.num_processes,  # NOTE: accelerate iterates num_proc times per step, not a bug
            num_training_steps=config.max_train_steps * accelerator.num_processes,
        )

    # Initialize DataLoaders
    if (initial_global_step == 0 or not loaded_accelerator) and config.seed is not None:  # reset the seed again,
        set_seed(config.seed, device_specific=True)  # differ in each process, for data loading
        print(f"Set seed to {config.seed}")

    # DataLoaders creation:
    wds_loader = True
    train_dataloader, test_dataloader = get_multi_dataloader(config, accelerator)

    def train_step(batch, train_loss, num_input_multi_views):
        """Run one forward/backward pass on ``batch`` and return the updated running loss."""
        threedgs_kwargs = config

        # Move tensors to the device and cast them to the working dtype
        if wds_loader:
            batch_keys = list(batch.keys())
            for k in batch_keys:
                if isinstance(batch[k], torch.Tensor):
                    batch[k] = batch[k].to(accelerator.device, non_blocking=True)
                    # 3DGS rendering with full precision: camera tensors keep their dtype
                    if k not in ['intrinsics_input', 'c2ws_input', 'cam_view', 'intrinsics']:
                        batch[k] = batch[k].to(weight_dtype)

        # Read data
        gt_images = batch['images_output']
        gt_depths = batch.get('depths_output', None)

        # Handle variable size in multi_views
        if 'num_input_multi_views' in batch:
            assert (batch['num_input_multi_views'][0] == batch['num_input_multi_views']).all(), f"Not supporting multi batch size for variable multi-view"
            num_input_multi_views = int(batch['num_input_multi_views'][0].item())
            batch['num_input_multi_views'] = num_input_multi_views

        # Encode video
        if 'rgb_latents' in batch:
            model_input = batch['rgb_latents'].to(weight_dtype)
        else:
            video = batch['images_input_vae'].to(weight_dtype)
            if threedgs_kwargs.use_rgb_decoder:
                model_input = video
            else:
                # Encode each multi-view video independently
                model_input = encode_multi_view_video(vae, video, num_input_multi_views, config.vae_backbone)
        batch['images_input_embed'] = model_input

        # Compute plucker with GPU for speed
        if threedgs_kwargs.get('compute_plucker_cuda', True):
            batch['plucker_embedding'], batch['rays_os'], batch['rays_ds'] = get_plucker_embedding_and_rays(
                batch['intrinsics_input'],
                batch['c2ws_input'],
                threedgs_kwargs.img_size,
                threedgs_kwargs.patch_size_out_factor,
                batch['flip_flag'],
                get_batch_index=False,
                dtype=dtype_map[threedgs_kwargs.compute_plucker_dtype],
                out_dtype=weight_dtype
            )

        # Encode time and plucker
        if threedgs_kwargs.get('use_time_embedding', False) and threedgs_kwargs.get('time_embedding_vae', False):
            batch = encode_latent_time_vae(batch, lambda x: encode_video(vae, x, config.vae_backbone), threedgs_kwargs.img_size)
        if threedgs_kwargs.get('plucker_embedding_vae', False):
            batch = encode_plucker_vae(batch, lambda x: encode_multi_view_video(vae, x, num_input_multi_views, config.vae_backbone))

        # Main model pass
        model_output = transformer(batch)

        # Compute losses
        pred_images = model_output['images_pred'].to(gt_images.dtype)
        pred_depths = model_output['depths_pred'].to(gt_images.dtype)
        pred_opacity = model_output['opacity_pred']
        train_loss, loss = compute_loss(accelerator, train_loss, pred_images, gt_images, pred_depths, gt_depths, pred_opacity, config, lpips_loss_module, lpips_img_size)

        # Backpropagate
        accelerator.backward(loss)
        if accelerator.sync_gradients:
            accelerator.clip_grad_norm_(transformer.parameters(), config.max_grad_norm)
            # double check the nan/inf, sometime doesn't work in DDP for no reason
            if optimizer.scaler is not None:
                optimizer.scaler._check_inf_per_device(optimizer.optimizer)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        return train_loss

    # Train!
    total_batch_size = config.batch_size * accelerator.num_processes * config.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Instantaneous batch size per device = {config.batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {config.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {config.max_train_steps}")
    logger.info(f" Output dir: {config.output_dir}")

    progress_bar = tqdm(
        range(0, config.max_train_steps),
        initial=initial_global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )  # type: ignore

    break_loop = False
    while True:
        train_loss = 0.0
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(transformer):
                train_loss = train_step(batch, train_loss, config.num_input_multi_views)

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                accelerator.log({
                    "train_loss": train_loss,
                    "lr": lr_scheduler.get_last_lr()[0],
                }, step=global_step)
                train_loss = 0.0

                if global_step % config.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                        if config.checkpoints_total_limit is not None:
                            checkpoints = os.listdir(config.output_dir)
                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
                            # don't count "permanent_checkpointing_steps" checkpoints towards limit
                            checkpoints = [ckpt for ckpt in checkpoints if
                                           (int(ckpt.split("-")[1]) % config.permanent_checkpointing_steps) != 0]
                            print(checkpoints)

                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
                            if len(checkpoints) >= config.checkpoints_total_limit:
                                num_to_remove = len(checkpoints) - config.checkpoints_total_limit + 1
                                removing_checkpoints = checkpoints[0:num_to_remove]

                                logger.info(
                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                                )
                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

                                for removing_checkpoint in removing_checkpoints:
                                    removing_checkpoint = os.path.join(config.output_dir, removing_checkpoint)  # type: ignore
                                    shutil.rmtree(removing_checkpoint)

                    save_path = os.path.join(config.output_dir, f"checkpoint-{global_step}")  # type: ignore
                    # Under FSDP / DeepSpeed every process takes part in saving.
                    if accelerator.is_main_process or accelerator.distributed_type in [DistributedType.FSDP, DistributedType.DEEPSPEED]:
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")
                    if config.save_multi_random_states and not accelerator.is_main_process:
                        save_path = os.path.join(config.output_dir, f"checkpoint-{global_step}")  # type: ignore
                        os.makedirs(save_path, exist_ok=True)
                        save_random_state(save_path, accelerator.process_index)

                if config.job_stop_steps is not None and global_step % config.job_stop_steps == 0:
                    logger.info('Reach Job Stop Steps')
                    break_loop = True
                    break

            logs = {"step_loss": train_loss, "lr": lr_scheduler.get_last_lr()[0], "dir": config.output_dir}
            if optimizer.step_was_skipped:
                logs["overflow"] = 1
                logs["scaler"] = optimizer.scaler._scale.item() if optimizer.scaler is not None else 1
                logger.warning(f"Gradient overflow. Skipping step {global_step}, scaler {logs['scaler']}")
            progress_bar.set_postfix(**logs)

            if global_step >= config.max_train_steps:  # type: ignore
                logger.info('Reach Max Train Steps')
                break_loop = True
                break

        if break_loop:
            break

    # Make sure every process has finished before shutting the trackers down.
    accelerator.wait_for_everyone()
    accelerator.end_training()
if __name__ == "__main__":
    # Launch timestamp in milliseconds, forwarded to `main`.
    app_start_time = time.time_ns() / 1_000_000

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True)
    parser.add_argument('--config_default', type=str, default='configs/training/default.yaml')
    parser.add_argument('--wandb_run_name', type=str, default=None, help="Name of run for wandb",)
    parser.add_argument('--wandb_group_name', type=str, default=None)
    # Unrecognised arguments are kept and applied below as dotlist overrides.
    args, unknown = parser.parse_known_args()

    schema = OmegaConf.load(args.config_default)
    config = OmegaConf.load(args.config)

    # Keys that only exist in the experiment config are first added to the
    # defaults (as None) before the two are merged.
    missing_keys = set(config.keys()) - set(schema.keys())
    for key in missing_keys:
        OmegaConf.update(schema, key, None, force_add=True)
    config = OmegaConf.merge(schema, config)

    # Precedence: defaults < experiment config < command-line overrides.
    cli = OmegaConf.from_dotlist(unknown)
    config = OmegaConf.merge(config, cli)

    try:
        main(config, args.wandb_run_name, args.wandb_group_name, app_start_time)
    finally:
        # `wandb` is only imported when it is installed; an unguarded call would
        # raise NameError here and hide the real exception from `main`.
        if is_wandb_available():
            wandb.finish()