| | """
|
| | trainer.py - wrapper and utility functions for network training
|
| | Compute loss, back-prop, update parameters, logging, etc.
|
| | """
|
| | import os
|
| | from pathlib import Path
|
| | from typing import Optional, Union
|
| |
|
| | import torch
|
| | import torch.distributed
|
| | import torch.optim as optim
|
| |
|
| |
|
| |
|
| | from omegaconf import DictConfig
|
| | from torch.nn.parallel import DistributedDataParallel as DDP
|
| |
|
| | from .model.flow_matching import FlowMatching
|
| | from .model.networks import get_my_mmaudio
|
| | from .model.sequence_config import CONFIG_16K, CONFIG_44K
|
| | from .model.utils.features_utils import FeaturesUtils
|
| | from .model.utils.parameter_groups import get_parameter_groups
|
| | from .model.utils.sample_utils import log_normal_sample
|
| | from .utils.dist_utils import (info_if_rank_zero, local_rank, string_if_rank_zero)
|
| | from .utils.log_integrator import Integrator
|
| | from .utils.logger import TensorboardLogger
|
| | from .utils.time_estimator import PartialTimeEstimator, TimeEstimator
|
| | from .utils.video_joiner import VideoJoiner
|
| |
|
| |
|
class Runner:
    """Wrapper around an MMAudio flow-matching network for distributed training.

    Owns the DDP-wrapped network, the flow-matching objective, feature
    extraction utilities, optimizer/LR schedule, optional post-hoc EMA, and all
    Tensorboard logging state. Provides train/validation/inference passes and
    checkpoint save/load.
    """

    def __init__(self,
                 cfg: DictConfig,
                 log: TensorboardLogger,
                 run_path: Union[str, Path],
                 for_training: bool = True,
                 latent_mean: Optional[torch.Tensor] = None,
                 latent_std: Optional[torch.Tensor] = None):
        """Build the network, feature utilities, and (if training) optimizer state.

        Args:
            cfg: experiment configuration (model name, sampling, EMA, LR, data paths).
            log: logger used for all scalar/audio/spectrogram output.
            run_path: directory where weights, checkpoints, and sampled videos go.
            for_training: when False, only inference-related state is created.
            latent_mean: optional latent normalization mean forwarded to the network.
            latent_std: optional latent normalization std forwarded to the network.
        """
        self.exp_id = cfg.exp_id
        self.use_amp = cfg.amp
        self.enable_grad_scaler = cfg.enable_grad_scaler
        self.for_training = for_training
        self.cfg = cfg

        # The model name suffix selects the audio sequence configuration.
        if cfg.model.endswith('16k'):
            self.seq_cfg = CONFIG_16K
            mode = '16k'
        elif cfg.model.endswith('44k'):
            self.seq_cfg = CONFIG_44K
            mode = '44k'
        else:
            raise ValueError(f'Unknown model: {cfg.model}')

        self.sample_rate = self.seq_cfg.sampling_rate
        self.duration_sec = self.seq_cfg.duration

        # Precomputed text feature for the empty string, used as the null text
        # condition. [0] drops the leading batch dimension of the saved tensor.
        empty_string_feat = torch.load('./ext_weights/empty_string.pth', weights_only=True)[0]
        self.network = DDP(get_my_mmaudio(cfg.model,
                                          latent_mean=latent_mean,
                                          latent_std=latent_std,
                                          empty_string_feat=empty_string_feat).cuda(),
                           device_ids=[local_rank],
                           broadcast_buffers=False)
        if cfg.compile:
            # Compile the bound train/val step methods rather than the module,
            # so the whole step (sampling + forward + loss) is captured.
            self.train_fn = torch.compile(self.train_fn)
            self.val_fn = torch.compile(self.val_fn)

        self.fm = FlowMatching(cfg.sampling.min_sigma,
                               inference_mode=cfg.sampling.method,
                               num_steps=cfg.sampling.num_steps)

        # EMA is maintained on rank 0 only; other ranks keep self.ema = None.
        # NOTE(review): PostHocEMA is not among this file's visible imports —
        # confirm it is imported elsewhere or add the import.
        if for_training and cfg.ema.enable and local_rank == 0:
            self.ema = PostHocEMA(self.network.module,
                                  sigma_rels=cfg.ema.sigma_rels,
                                  update_every=cfg.ema.update_every,
                                  checkpoint_every_num_steps=cfg.ema.checkpoint_every,
                                  checkpoint_folder=cfg.ema.checkpoint_folder,
                                  step_size_correction=True).cuda()
            self.ema_start = cfg.ema.start
        else:
            self.ema = None

        # Per-rank RNG: seed is offset by local_rank so ranks draw different noise.
        self.rng = torch.Generator(device='cuda')
        self.rng.manual_seed(cfg['seed'] + local_rank)

        # Feature utilities (VAE decode / vocoder / Synchformer); the 44k path
        # has no separate vocoder checkpoint.
        if mode == '16k':
            self.features = FeaturesUtils(
                tod_vae_ckpt=cfg['vae_16k_ckpt'],
                bigvgan_vocoder_ckpt=cfg['bigvgan_vocoder_ckpt'],
                synchformer_ckpt=cfg['synchformer_ckpt'],
                enable_conditions=True,
                mode=mode,
                need_vae_encoder=False,
            )
        elif mode == '44k':
            self.features = FeaturesUtils(
                tod_vae_ckpt=cfg['vae_44k_ckpt'],
                synchformer_ckpt=cfg['synchformer_ckpt'],
                enable_conditions=True,
                mode=mode,
                need_vae_encoder=False,
            )
        self.features = self.features.cuda().eval()

        if cfg.compile:
            self.features.compile()

        # Sampling / classifier-free-guidance hyperparameters.
        self.log_normal_sampling_mean = cfg.sampling.mean
        self.log_normal_sampling_scale = cfg.sampling.scale
        self.null_condition_probability = cfg.null_condition_probability
        self.cfg_strength = cfg.cfg_strength

        # Logging and sampled-video output.
        self.log = log
        self.run_path = Path(run_path)
        vgg_cfg = cfg.data.VGGSound
        if for_training:
            self.val_video_joiner = VideoJoiner(vgg_cfg.root, self.run_path / 'val-sampled-videos',
                                                self.sample_rate, self.duration_sec)
        else:
            self.test_video_joiner = VideoJoiner(vgg_cfg.root,
                                                 self.run_path / 'test-sampled-videos',
                                                 self.sample_rate, self.duration_sec)
        string_if_rank_zero(self.log, 'model_size',
                            f'{sum([param.nelement() for param in self.network.parameters()])}')
        string_if_rank_zero(
            self.log, 'number_of_parameters_that_require_gradient: ',
            str(
                sum([
                    param.nelement()
                    for param in filter(lambda p: p.requires_grad, self.network.parameters())
                ])))
        info_if_rank_zero(self.log, 'torch version: ' + torch.__version__)
        self.train_integrator = Integrator(self.log, distributed=True)
        self.val_integrator = Integrator(self.log, distributed=True)

        if for_training:
            self.enter_train()
            parameter_groups = get_parameter_groups(self.network, cfg, print_log=(local_rank == 0))
            self.optimizer = optim.AdamW(parameter_groups,
                                         lr=cfg['learning_rate'],
                                         weight_decay=cfg['weight_decay'],
                                         betas=[0.9, 0.95],
                                         eps=1e-6 if self.use_amp else 1e-8,
                                         fused=True)
            if self.enable_grad_scaler:
                self.scaler = torch.amp.GradScaler(init_scale=2048)
            self.clip_grad_norm = cfg['clip_grad_norm']

            # Linear warmup followed by the configured main schedule.
            linear_warmup_steps = cfg['linear_warmup_steps']

            # (typo in the local parameter name is harmless; kept as-is)
            def warmup(currrent_step: int):
                return (currrent_step + 1) / (linear_warmup_steps + 1)

            warmup_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=warmup)

            if cfg['lr_schedule'] == 'constant':
                next_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda _: 1)
            elif cfg['lr_schedule'] == 'poly':
                total_num_iter = cfg['iterations']
                next_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer,
                                                             lr_lambda=lambda x:
                                                             (1 - (x / total_num_iter))**0.9)
            elif cfg['lr_schedule'] == 'step':
                next_scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer,
                                                                cfg['lr_schedule_steps'],
                                                                cfg['lr_schedule_gamma'])
            else:
                raise NotImplementedError

            # Switch from warmup to the main schedule after linear_warmup_steps.
            self.scheduler = optim.lr_scheduler.SequentialLR(self.optimizer,
                                                             [warmup_scheduler, next_scheduler],
                                                             [linear_warmup_steps])

            # Logging / checkpointing cadence.
            self.log_text_interval = cfg['log_text_interval']
            self.log_extra_interval = cfg['log_extra_interval']
            self.save_weights_interval = cfg['save_weights_interval']
            self.save_checkpoint_interval = cfg['save_checkpoint_interval']
            self.save_copy_iterations = cfg['save_copy_iterations']
            self.num_iterations = cfg['num_iterations']
            if cfg['debug']:
                self.log_text_interval = self.log_extra_interval = 1

            # Timers used to estimate throughput and data-loading time.
            self.log.batch_timer = TimeEstimator(self.num_iterations, self.log_text_interval)
            self.log.data_timer = PartialTimeEstimator(self.num_iterations, 1, ema_alpha=0.9)
        else:
            self.enter_val()
|
| |
|
    def train_fn(
        self,
        clip_f: torch.Tensor,
        sync_f: torch.Tensor,
        text_f: torch.Tensor,
        a_mean: torch.Tensor,
        a_std: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """One flow-matching training step (no backward; loss only).

        Samples the target latent x1 from the per-item Gaussian (a_mean, a_std),
        draws a timestep t, builds the interpolant xt, applies classifier-free-
        guidance condition dropout, and computes the flow-matching loss.

        Note: clip_f/sync_f/text_f are masked IN PLACE by the condition dropout
        below — callers must pass clones if they need the originals.

        Returns:
            (x1, loss, mean_loss, t): normalized target latent, per-element
            loss, scalar mean loss, and the sampled timesteps.
        """
        # Sample the ground-truth latent from its encoder posterior.
        a_randn = torch.empty_like(a_mean).normal_(generator=self.rng)
        x1 = a_mean + a_std * a_randn
        bs = x1.shape[0]

        x1 = self.network.module.normalize(x1)

        # Log-normal timestep sampling (mean/scale from cfg.sampling).
        t = log_normal_sample(x1,
                              generator=self.rng,
                              m=self.log_normal_sampling_mean,
                              s=self.log_normal_sampling_scale)
        x0, x1, xt, (clip_f, sync_f, text_f) = self.fm.get_x0_xt_c(x1,
                                                                   t,
                                                                   Cs=[clip_f, sync_f, text_f],
                                                                   generator=self.rng)

        # Classifier-free guidance: independently drop the video conditions
        # (clip+sync together) and the text condition with the configured
        # probability, replacing them with the learned "empty" features.
        samples = torch.rand(bs, device=x1.device, generator=self.rng)
        null_video = (samples < self.null_condition_probability)
        clip_f[null_video] = self.network.module.empty_clip_feat
        sync_f[null_video] = self.network.module.empty_sync_feat

        samples = torch.rand(bs, device=x1.device, generator=self.rng)
        null_text = (samples < self.null_condition_probability)
        text_f[null_text] = self.network.module.empty_string_feat

        # Predict the velocity field at (xt, t) and score it against (x0, x1).
        pred_v = self.network(xt, clip_f, sync_f, text_f, t)
        loss = self.fm.loss(pred_v, x0, x1)
        mean_loss = loss.mean()
        return x1, loss, mean_loss, t
|
| |
|
    def val_fn(
        self,
        clip_f: torch.Tensor,
        sync_f: torch.Tensor,
        text_f: torch.Tensor,
        x1: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """One flow-matching validation step; mirrors train_fn but takes x1 directly.

        Applies the same condition dropout as training so validation loss is
        directly comparable to the training loss. Conditions are masked IN
        PLACE — callers pass clones (see validation_pass).

        Returns:
            (loss, mean_loss, t): per-element loss, scalar mean loss, and the
            sampled timesteps.
            NOTE(review): the annotated 4-tuple return type does not match the
            actual 3-tuple returned — likely a copy-paste from train_fn.
        """
        bs = x1.shape[0]

        x1 = self.network.module.normalize(x1)
        # Same log-normal timestep sampling as in training.
        t = log_normal_sample(x1,
                              generator=self.rng,
                              m=self.log_normal_sampling_mean,
                              s=self.log_normal_sampling_scale)
        x0, x1, xt, (clip_f, sync_f, text_f) = self.fm.get_x0_xt_c(x1,
                                                                   t,
                                                                   Cs=[clip_f, sync_f, text_f],
                                                                   generator=self.rng)

        # Condition dropout, identical to train_fn (video conditions together,
        # text independently).
        samples = torch.rand(bs, device=x1.device, generator=self.rng)

        null_video = (samples < self.null_condition_probability)

        clip_f[null_video] = self.network.module.empty_clip_feat
        sync_f[null_video] = self.network.module.empty_sync_feat

        samples = torch.rand(bs, device=x1.device, generator=self.rng)
        null_text = (samples < self.null_condition_probability)
        text_f[null_text] = self.network.module.empty_string_feat

        pred_v = self.network(xt, clip_f, sync_f, text_f, t)

        loss = self.fm.loss(pred_v, x0, x1)
        mean_loss = loss.mean()
        return loss, mean_loss, t
|
| |
|
| | def train_pass(self, data, it: int = 0):
|
| |
|
| | if not self.for_training:
|
| | raise ValueError('train_pass() should not be called when not training.')
|
| |
|
| | self.enter_train()
|
| | with torch.amp.autocast('cuda', enabled=self.use_amp, dtype=torch.bfloat16):
|
| | clip_f = data['clip_features'].cuda(non_blocking=True)
|
| | sync_f = data['sync_features'].cuda(non_blocking=True)
|
| | text_f = data['text_features'].cuda(non_blocking=True)
|
| | video_exist = data['video_exist'].cuda(non_blocking=True)
|
| | text_exist = data['text_exist'].cuda(non_blocking=True)
|
| | a_mean = data['a_mean'].cuda(non_blocking=True)
|
| | a_std = data['a_std'].cuda(non_blocking=True)
|
| |
|
| |
|
| | clip_f[~video_exist] = self.network.module.empty_clip_feat
|
| | sync_f[~video_exist] = self.network.module.empty_sync_feat
|
| | text_f[~text_exist] = self.network.module.empty_string_feat
|
| |
|
| | self.log.data_timer.end()
|
| | if it % self.log_extra_interval == 0:
|
| | unmasked_clip_f = clip_f.clone()
|
| | unmasked_sync_f = sync_f.clone()
|
| | unmasked_text_f = text_f.clone()
|
| | x1, loss, mean_loss, t = self.train_fn(clip_f, sync_f, text_f, a_mean, a_std)
|
| |
|
| | self.train_integrator.add_dict({'loss': mean_loss})
|
| |
|
| | if it % self.log_text_interval == 0 and it != 0:
|
| | self.train_integrator.add_scalar('lr', self.scheduler.get_last_lr()[0])
|
| | self.train_integrator.add_binned_tensor('binned_loss', loss, t)
|
| | self.train_integrator.finalize('train', it)
|
| | self.train_integrator.reset_except_hooks()
|
| |
|
| |
|
| | self.optimizer.zero_grad(set_to_none=True)
|
| | if self.enable_grad_scaler:
|
| | self.scaler.scale(mean_loss).backward()
|
| | self.scaler.unscale_(self.optimizer)
|
| | grad_norm = torch.nn.utils.clip_grad_norm_(self.network.parameters(),
|
| | self.clip_grad_norm)
|
| | self.scaler.step(self.optimizer)
|
| | self.scaler.update()
|
| | else:
|
| | mean_loss.backward()
|
| | grad_norm = torch.nn.utils.clip_grad_norm_(self.network.parameters(),
|
| | self.clip_grad_norm)
|
| | self.optimizer.step()
|
| |
|
| | if self.ema is not None and it >= self.ema_start:
|
| | self.ema.update()
|
| | self.scheduler.step()
|
| | self.integrator.add_scalar('grad_norm', grad_norm)
|
| |
|
| | self.enter_val()
|
| | with torch.amp.autocast('cuda', enabled=self.use_amp,
|
| | dtype=torch.bfloat16), torch.inference_mode():
|
| | try:
|
| | if it % self.log_extra_interval == 0:
|
| |
|
| |
|
| | x1 = self.network.module.unnormalize(x1[0:1])
|
| | mel = self.features.decode(x1)
|
| | audio = self.features.vocode(mel).cpu()[0]
|
| | self.log.log_spectrogram('train', f'spec-gt-r{local_rank}', mel.cpu()[0], it)
|
| | self.log.log_audio('train',
|
| | f'audio-gt-r{local_rank}',
|
| | audio,
|
| | it,
|
| | sample_rate=self.sample_rate)
|
| |
|
| |
|
| | x0 = torch.empty_like(x1[0:1]).normal_(generator=self.rng)
|
| | clip_f = unmasked_clip_f[0:1]
|
| | sync_f = unmasked_sync_f[0:1]
|
| | text_f = unmasked_text_f[0:1]
|
| | conditions = self.network.module.preprocess_conditions(clip_f, sync_f, text_f)
|
| | empty_conditions = self.network.module.get_empty_conditions(x0.shape[0])
|
| | cfg_ode_wrapper = lambda t, x: self.network.module.ode_wrapper(
|
| | t, x, conditions, empty_conditions, self.cfg_strength)
|
| | x1_hat = self.fm.to_data(cfg_ode_wrapper, x0)
|
| | x1_hat = self.network.module.unnormalize(x1_hat)
|
| | mel = self.features.decode(x1_hat)
|
| | audio = self.features.vocode(mel).cpu()[0]
|
| | self.log.log_spectrogram('train', f'spec-r{local_rank}', mel.cpu()[0], it)
|
| | self.log.log_audio('train',
|
| | f'audio-r{local_rank}',
|
| | audio,
|
| | it,
|
| | sample_rate=self.sample_rate)
|
| | except Exception as e:
|
| | self.log.warning(f'Error in extra logging: {e}')
|
| | if self.cfg.debug:
|
| | raise
|
| |
|
| |
|
| | save_copy = it in self.save_copy_iterations
|
| |
|
| | if (it % self.save_weights_interval == 0 and it != 0) or save_copy:
|
| | self.save_weights(it)
|
| |
|
| | if it % self.save_checkpoint_interval == 0 and it != 0:
|
| | self.save_checkpoint(it, save_copy=save_copy)
|
| |
|
| | self.log.data_timer.start()
|
| |
|
| | @torch.inference_mode()
|
| | def validation_pass(self, data, it: int = 0):
|
| | self.enter_val()
|
| | with torch.amp.autocast('cuda', enabled=self.use_amp, dtype=torch.bfloat16):
|
| | clip_f = data['clip_features'].cuda(non_blocking=True)
|
| | sync_f = data['sync_features'].cuda(non_blocking=True)
|
| | text_f = data['text_features'].cuda(non_blocking=True)
|
| | video_exist = data['video_exist'].cuda(non_blocking=True)
|
| | text_exist = data['text_exist'].cuda(non_blocking=True)
|
| | a_mean = data['a_mean'].cuda(non_blocking=True)
|
| | a_std = data['a_std'].cuda(non_blocking=True)
|
| |
|
| | clip_f[~video_exist] = self.network.module.empty_clip_feat
|
| | sync_f[~video_exist] = self.network.module.empty_sync_feat
|
| | text_f[~text_exist] = self.network.module.empty_string_feat
|
| | a_randn = torch.empty_like(a_mean).normal_(generator=self.rng)
|
| | x1 = a_mean + a_std * a_randn
|
| |
|
| | self.log.data_timer.end()
|
| | loss, mean_loss, t = self.val_fn(clip_f.clone(), sync_f.clone(), text_f.clone(), x1)
|
| |
|
| | self.val_integrator.add_binned_tensor('binned_loss', loss, t)
|
| | self.val_integrator.add_dict({'loss': mean_loss})
|
| |
|
| | self.log.data_timer.start()
|
| |
|
    @torch.inference_mode()
    def inference_pass(self,
                       data,
                       it: int,
                       data_cfg: DictConfig,
                       *,
                       save_eval: bool = True) -> Path:
        """Sample audio from noise for a batch and write it (plus demo videos) to disk.

        Args:
            data: batch dict with conditioning features, existence masks,
                'a_mean' (used only for the noise shape), and per-item 'id'.
            it: global iteration; used to name the output directory.
            data_cfg: dataset config providing output_subdir / tag / root.
            save_eval: when True, outputs go under an iteration-named subdir
                (and a demo video is joined); otherwise under 'val-cache'.

        Returns:
            Path of the directory the audio files were written to.
        """
        self.enter_val()
        with torch.amp.autocast('cuda', enabled=self.use_amp, dtype=torch.bfloat16):
            clip_f = data['clip_features'].cuda(non_blocking=True)
            sync_f = data['sync_features'].cuda(non_blocking=True)
            text_f = data['text_features'].cuda(non_blocking=True)
            video_exist = data['video_exist'].cuda(non_blocking=True)
            text_exist = data['text_exist'].cuda(non_blocking=True)
            a_mean = data['a_mean'].cuda(non_blocking=True)

            # Items without video/text get the learned "empty" features.
            clip_f[~video_exist] = self.network.module.empty_clip_feat
            sync_f[~video_exist] = self.network.module.empty_sync_feat
            text_f[~text_exist] = self.network.module.empty_string_feat

            # Integrate the flow ODE from Gaussian noise with classifier-free
            # guidance, then decode to mel and vocode to a waveform.
            x0 = torch.empty_like(a_mean).normal_(generator=self.rng)
            conditions = self.network.module.preprocess_conditions(clip_f, sync_f, text_f)
            empty_conditions = self.network.module.get_empty_conditions(x0.shape[0])
            cfg_ode_wrapper = lambda t, x: self.network.module.ode_wrapper(
                t, x, conditions, empty_conditions, self.cfg_strength)
            x1_hat = self.fm.to_data(cfg_ode_wrapper, x0)
            x1_hat = self.network.module.unnormalize(x1_hat)
            mel = self.features.decode(x1_hat)
            audio = self.features.vocode(mel).cpu()
            for i in range(audio.shape[0]):
                video_id = data['id'][i]
                if (not self.for_training) and i == 0:
                    # Test mode: join the first sampled audio with its source video.
                    self.test_video_joiner.join(video_id, f'{video_id}', audio[i].transpose(0, 1))

                if data_cfg.output_subdir is not None:
                    # Write into an evaluation subdir (iteration-named, or a
                    # reusable 'val-cache' when not saving for eval).
                    if save_eval:
                        iter_naming = f'{it:09d}'
                    else:
                        iter_naming = 'val-cache'
                    audio_dir = self.log.log_audio(iter_naming,
                                                   f'{video_id}',
                                                   audio[i],
                                                   it=None,
                                                   sample_rate=self.sample_rate,
                                                   subdir=Path(data_cfg.output_subdir))
                    if save_eval and i == 0:
                        # NOTE(review): val_video_joiner only exists when
                        # for_training=True (see __init__); this branch would
                        # raise AttributeError for a test-mode Runner with
                        # save_eval=True — confirm intended call pattern.
                        self.val_video_joiner.join(video_id, f'{iter_naming}-{video_id}',
                                                   audio[i].transpose(0, 1))
                else:
                    # No eval subdir configured: write under a tag-named folder.
                    audio_dir = self.log.log_audio(f'{data_cfg.tag}-sampled',
                                                   f'{video_id}',
                                                   audio[i],
                                                   it=None,
                                                   sample_rate=self.sample_rate)

        # audio_dir comes from the last loop iteration; all items share it.
        return Path(audio_dir)
|
| |
|
    @torch.inference_mode()
    def eval(self, audio_dir: Path, it: int, data_cfg: DictConfig) -> dict[str, float]:
        """Compute and log evaluation metrics for previously sampled audio.

        Runs feature extraction and metric computation on rank 0 only; other
        ranks return None. (The method name shadows the `eval` builtin, but the
        signature is kept for callers.)

        NOTE(review): `extract` and `evaluate` are not among this file's visible
        imports — confirm they are imported elsewhere (e.g. from an av-benchmark
        package) or this raises NameError at runtime.
        """
        with torch.amp.autocast('cuda', enabled=False):
            if local_rank == 0:
                extract(audio_path=audio_dir,
                        output_path=audio_dir / 'cache',
                        device='cuda',
                        batch_size=32,
                        audio_length=8)
                output_metrics = evaluate(gt_audio_cache=Path(data_cfg.gt_cache),
                                          pred_audio_cache=audio_dir / 'cache')
                for k, v in output_metrics.items():
                    # Log each metric both as a scalar and as readable text.
                    self.log.log_scalar(f'{data_cfg.tag}/{k}', v, it)
                    self.log.info(f'{data_cfg.tag}/{k:<10}: {v:.10f}')
            else:
                output_metrics = None

        return output_metrics
|
| |
|
| | def save_weights(self, it, save_copy=False):
|
| | if local_rank != 0:
|
| | return
|
| |
|
| | os.makedirs(self.run_path, exist_ok=True)
|
| | if save_copy:
|
| | model_path = self.run_path / f'{self.exp_id}_{it}.pth'
|
| | torch.save(self.network.module.state_dict(), model_path)
|
| | self.log.info(f'Network weights saved to {model_path}.')
|
| |
|
| |
|
| | model_path = self.run_path / f'{self.exp_id}_last.pth'
|
| | if model_path.exists():
|
| | shadow_path = model_path.with_name(model_path.name.replace('last', 'shadow'))
|
| | model_path.replace(shadow_path)
|
| | self.log.info(f'Network weights shadowed to {shadow_path}.')
|
| |
|
| | torch.save(self.network.module.state_dict(), model_path)
|
| | self.log.info(f'Network weights saved to {model_path}.')
|
| |
|
| | def save_checkpoint(self, it, save_copy=False):
|
| | if local_rank != 0:
|
| | return
|
| |
|
| | checkpoint = {
|
| | 'it': it,
|
| | 'weights': self.network.module.state_dict(),
|
| | 'optimizer': self.optimizer.state_dict(),
|
| | 'scheduler': self.scheduler.state_dict(),
|
| | 'ema': self.ema.state_dict() if self.ema is not None else None,
|
| | }
|
| |
|
| | os.makedirs(self.run_path, exist_ok=True)
|
| | if save_copy:
|
| | model_path = self.run_path / f'{self.exp_id}_ckpt_{it}.pth'
|
| | torch.save(checkpoint, model_path)
|
| | self.log.info(f'Checkpoint saved to {model_path}.')
|
| |
|
| |
|
| | model_path = self.run_path / f'{self.exp_id}_ckpt_last.pth'
|
| | if model_path.exists():
|
| | shadow_path = model_path.with_name(model_path.name.replace('last', 'shadow'))
|
| | model_path.replace(shadow_path)
|
| | self.log.info(f'Checkpoint shadowed to {shadow_path}.')
|
| |
|
| | torch.save(checkpoint, model_path)
|
| | self.log.info(f'Checkpoint saved to {model_path}.')
|
| |
|
| | def get_latest_checkpoint_path(self):
|
| | ckpt_path = self.run_path / f'{self.exp_id}_ckpt_last.pth'
|
| | if not ckpt_path.exists():
|
| | info_if_rank_zero(self.log, f'No checkpoint found at {ckpt_path}.')
|
| | return None
|
| | return ckpt_path
|
| |
|
| | def get_latest_weight_path(self):
|
| | weight_path = self.run_path / f'{self.exp_id}_last.pth'
|
| | if not weight_path.exists():
|
| | self.log.info(f'No weight found at {weight_path}.')
|
| | return None
|
| | return weight_path
|
| |
|
| | def get_final_ema_weight_path(self):
|
| | weight_path = self.run_path / f'{self.exp_id}_ema_final.pth'
|
| | if not weight_path.exists():
|
| | self.log.info(f'No weight found at {weight_path}.')
|
| | return None
|
| | return weight_path
|
| |
|
| | def load_checkpoint(self, path):
|
| |
|
| | map_location = 'cuda:%d' % local_rank
|
| | checkpoint = torch.load(path, map_location={'cuda:0': map_location}, weights_only=True)
|
| |
|
| | it = checkpoint['it']
|
| | weights = checkpoint['weights']
|
| | optimizer = checkpoint['optimizer']
|
| | scheduler = checkpoint['scheduler']
|
| | if self.ema is not None:
|
| | self.ema.load_state_dict(checkpoint['ema'])
|
| | self.log.info(f'EMA states loaded from step {self.ema.step}')
|
| |
|
| | map_location = 'cuda:%d' % local_rank
|
| | self.network.module.load_state_dict(weights)
|
| | self.optimizer.load_state_dict(optimizer)
|
| | self.scheduler.load_state_dict(scheduler)
|
| |
|
| | self.log.info(f'Global iteration {it} loaded.')
|
| | self.log.info('Network weights, optimizer states, and scheduler states loaded.')
|
| |
|
| | return it
|
| |
|
| | def load_weights_in_memory(self, src_dict):
|
| | self.network.module.load_weights(src_dict)
|
| | self.log.info('Network weights loaded from memory.')
|
| |
|
| | def load_weights(self, path):
|
| |
|
| | map_location = 'cuda:%d' % local_rank
|
| | src_dict = torch.load(path, map_location={'cuda:0': map_location}, weights_only=True)
|
| |
|
| | self.log.info(f'Importing network weights from {path}...')
|
| | self.load_weights_in_memory(src_dict)
|
| |
|
| | def weights(self):
|
| | return self.network.module.state_dict()
|
| |
|
| | def enter_train(self):
|
| | self.integrator = self.train_integrator
|
| | self.network.train()
|
| | return self
|
| |
|
| | def enter_val(self):
|
| | self.network.eval()
|
| | return self
|
| |
|