import traceback
import os
import time
import math
import argparse
import shutil
import torch
import safetensors.torch  # import the submodule so safetensors.torch.* resolves
from omegaconf import OmegaConf
from abc import abstractmethod
from contextlib import contextmanager
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
import cv2
import numpy as np

from lam.utils.logging import configure_logger
from lam.utils.compile import configure_dynamo
from lam.runners.abstract import Runner

logger = get_logger(__name__)


def parse_configs():
    # CLI flags plus free-form "key=value" overrides merged into the YAML config.
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='./assets/config.yaml')
    parser.add_argument('--resume', type=str, default='')
    args, unknown = parser.parse_known_args()

    cfg = OmegaConf.load(args.config)
    cli_cfg = OmegaConf.from_cli(unknown)
    cfg = OmegaConf.merge(cfg, cli_cfg)
    if len(args.resume) > 0:
        cfg.train.resume = args.resume

    return cfg

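# Example invocation (hypothetical entry-point path; the module name and the
# override keys below are illustrative, not taken from this file):
#
#     python -m lam.runners.train_runner --config ./assets/config.yaml \
#         train.batch_size=4 experiment.child=debug_run
#
# Everything after --config that parse_known_args() does not recognize is
# forwarded to OmegaConf.from_cli() and merged over the YAML values.
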
class Trainer(Runner):

    def __init__(self):
        super().__init__()

        self.cfg = parse_configs()
        self.has_disc = self.cfg.model.has_disc if hasattr(self.cfg.model, "has_disc") else False
        self.timestamp = time.strftime("%Y%m%d-%H%M%S")

        self.accelerator = Accelerator(
            mixed_precision=self.cfg.train.mixed_precision,
            gradient_accumulation_steps=self.cfg.train.accum_steps,
            log_with=tuple(self.cfg.logger.trackers),
            project_config=ProjectConfiguration(
                logging_dir=self.cfg.logger.tracker_root,
            ),
            use_seedable_sampler=True,
            kwargs_handlers=[
                DistributedDataParallelKwargs(
                    find_unused_parameters=self.cfg.train.find_unused_parameters,
                ),
            ],
        )

        self.weight_dtype = self.get_weight_dtype()
        print(f"weight_dtype: {self.weight_dtype}")

        set_seed(self.cfg.experiment.seed, device_specific=True)
        with self.accelerator.main_process_first():
            configure_logger(
                stream_level=self.cfg.logger.stream_level,
                log_level=self.cfg.logger.log_level,
                file_path=os.path.join(
                    self.cfg.logger.log_root,
                    self.cfg.experiment.parent, self.cfg.experiment.child,
                    f"{self.timestamp}.log",
                ) if self.accelerator.is_main_process else None,
            )
        logger.info(self.accelerator.state, main_process_only=False, in_order=True)
        configure_dynamo(dict(self.cfg.compile))

        # Populated by the concrete subclass (see the _build_* hooks below).
        self.model: torch.nn.Module = None
        self.optimizer: torch.optim.Optimizer = None
        self.scheduler: torch.optim.lr_scheduler.LRScheduler = None
        self.train_loader: torch.utils.data.DataLoader = None
        self.val_loader: torch.utils.data.DataLoader = None
        self.N_max_global_steps: int = None
        self.N_global_steps_per_epoch: int = None
        self.global_step: int = 0
        self.current_epoch: int = 0

    def __enter__(self):
        self.accelerator.init_trackers(
            project_name=f"{self.cfg.experiment.parent}/{self.cfg.experiment.child}",
        )
        self.prepare_everything()
        self.log_inital_info()

        self.trackers_logging_dir = f"{self.cfg.logger.tracker_root}/{self.cfg.experiment.parent}/{self.cfg.experiment.child}"
        os.makedirs(self.trackers_logging_dir, exist_ok=True)
        self.snapshot_cfg(self.cfg)
        return self

    def get_weight_dtype(self):
        # Map the accelerate mixed-precision setting to a torch dtype.
        if self.accelerator.mixed_precision == "fp16":
            return torch.float16
        elif self.accelerator.mixed_precision == "bf16":
            return torch.bfloat16
        elif self.accelerator.mixed_precision == "no":
            return torch.float32
        else:
            raise NotImplementedError(f"Unsupported mixed precision mode: {self.accelerator.mixed_precision}")

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.accelerator.end_training()

    @staticmethod
    def control(option: str = None, synchronized: bool = False):
        # Decorator factory: optionally wrap the method with an accelerate
        # decorator named by `option` (e.g. 'on_main_process'); optionally
        # barrier all processes after the call.
        def decorator(func):
            def wrapper(self, *args, **kwargs):
                if option is None or hasattr(self.accelerator, option):
                    accelerated_func = getattr(self.accelerator, option)(func) if option is not None else func
                    result = accelerated_func(self, *args, **kwargs)
                    if synchronized:
                        self.accelerator.wait_for_everyone()
                    return result
                else:
                    raise AttributeError(f"Accelerator has no attribute {option}")
            return wrapper
        return decorator

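    # Note: `control` is invoked during class-body evaluation, e.g.
    #
    #     @control('on_main_process')
    #     def snapshot_cfg(self, cfg): ...
    #
    # Calling the staticmethod object directly like this requires Python 3.10+;
    # on older interpreters the decorator would have to be a plain module-level
    # function.
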
    @contextmanager
    def exec_in_order(self):
        for rank in range(self.accelerator.num_processes):
            try:
                if self.accelerator.process_index == rank:
                    yield
            finally:
                self.accelerator.wait_for_everyone()

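    # Usage sketch (hypothetical; not exercised elsewhere in this file):
    # serialize a side effect across ranks so rank 0 runs first, then rank 1,
    # and so on.
    #
    #     with self.exec_in_order():
    #         print(f"rank {self.accelerator.process_index} writing its shard")
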
    @property
    def device(self):
        return self.accelerator.device

    @property
    def is_distributed(self) -> bool:
        return self.accelerator.num_processes > 1

    def prepare_everything(self, is_dist_validation: bool = True):
        # Hand everything to accelerate; val_loader is only prepared when
        # validation itself runs distributed.
        if is_dist_validation:
            if not self.has_disc:
                self.model, self.optimizer, self.train_loader, self.val_loader = \
                    self.accelerator.prepare(
                        self.model, self.optimizer, self.train_loader, self.val_loader,
                    )
            else:
                self.model, self.model_disc, self.optimizer, self.optimizer_disc, self.train_loader, self.val_loader = \
                    self.accelerator.prepare(
                        self.model, self.model_disc, self.optimizer, self.optimizer_disc, self.train_loader, self.val_loader,
                    )
        else:
            if not self.has_disc:
                self.model, self.optimizer, self.train_loader = \
                    self.accelerator.prepare(
                        self.model, self.optimizer, self.train_loader,
                    )
            else:
                self.model, self.model_disc, self.optimizer, self.optimizer_disc, self.train_loader = \
                    self.accelerator.prepare(
                        self.model, self.model_disc, self.optimizer, self.optimizer_disc, self.train_loader,
                    )

        self.accelerator.register_for_checkpointing(self.scheduler)
        if self.has_disc:
            self.accelerator.register_for_checkpointing(self.scheduler_disc)

        N_total_batch_size = self.cfg.train.batch_size * self.accelerator.num_processes * self.cfg.train.accum_steps
        self.N_global_steps_per_epoch = math.ceil(len(self.train_loader) / self.cfg.train.accum_steps)
        self.N_max_global_steps = self.N_global_steps_per_epoch * self.cfg.train.epochs
        if self.cfg.train.debug_global_steps is not None:
            logger.warning(f"Overriding max global steps from {self.N_max_global_steps} to {self.cfg.train.debug_global_steps}")
            self.N_max_global_steps = self.cfg.train.debug_global_steps

        print("======== Trainable parameters ========")
        print(f"** Total: {sum(p.numel() for p in self.model.parameters() if p.requires_grad) / 1e6}M")
        logger.info("======== Statistics ========")
        logger.info(f"** N_max_global_steps: {self.N_max_global_steps}")
        logger.info(f"** N_total_batch_size: {N_total_batch_size}")
        logger.info(f"** N_epochs: {self.cfg.train.epochs}")
        logger.info(f"** N_global_steps_per_epoch: {self.N_global_steps_per_epoch}")
        logger.debug(f"** Prepared loader length: {len(self.train_loader)}")
        logger.info(f"** Distributed validation: {is_dist_validation}")
        logger.info("============================")
        logger.info("======== Trainable parameters ========")
        logger.info(f"** Total: {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}")
        for sub_name, sub_module in self.accelerator.unwrap_model(self.model).named_children():
            logger.info(f"** {sub_name}: {sum(p.numel() for p in sub_module.parameters() if p.requires_grad)}")
        logger.info("=====================================")
        self.accelerator.wait_for_everyone()

        self.load_ckpt_or_auto_resume_(self.cfg)
        self.register_hooks()

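    # Worked example of the bookkeeping above (hypothetical numbers): with
    # batch_size=2, 8 processes and accum_steps=4, the effective batch size is
    # N_total_batch_size = 2 * 8 * 4 = 64; a prepared train_loader of length
    # 100 then yields N_global_steps_per_epoch = ceil(100 / 4) = 25 optimizer
    # steps per epoch.
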
    @abstractmethod
    def register_hooks(self):
        pass

    def auto_resume_(self, cfg, ckpt_root=None) -> bool:
        if ckpt_root is None:
            ckpt_root = os.path.join(
                cfg.saver.checkpoint_root,
                cfg.experiment.parent, cfg.experiment.child,
            )
        if not os.path.exists(ckpt_root):
            return False
        ckpt_dirs = os.listdir(ckpt_root)
        if len(ckpt_dirs) == 0:
            return False
        # Directory names are zero-padded global steps, so a lexicographic
        # sort is also a numeric sort.
        ckpt_dirs.sort()
        latest_ckpt = ckpt_dirs[-1]
        latest_ckpt_dir = os.path.join(ckpt_root, latest_ckpt)
        logger.info(f"======== Auto-resume from {latest_ckpt_dir} ========")
        self.accelerator.load_state(latest_ckpt_dir)
        self.global_step = int(latest_ckpt)
        self.current_epoch = self.global_step // self.N_global_steps_per_epoch
        return True

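    # Expected checkpoint layout (step names are produced by _save_checkpoint
    # with a %06d format; the concrete paths below are illustrative):
    #
    #     {checkpoint_root}/{parent}/{child}/000500/
    #     {checkpoint_root}/{parent}/{child}/001000/
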
    def load_model_(self, cfg):
        logger.info(f"======== Loading model from {cfg.saver.load_model} ========")
        try:
            safetensors.torch.load_model(
                self.accelerator.unwrap_model(self.model),
                cfg.saver.load_model,
                strict=cfg.saver.load_model_strict if hasattr(cfg.saver, "load_model_strict") else True,
            )
        except Exception:
            traceback.print_exc()
            # Fall back to a tolerant load: drop keys that are known to be
            # incompatible, then load the remainder with strict=False.
            model = self.accelerator.unwrap_model(self.model)
            model_state_dict = model.state_dict()
            state_dict = safetensors.torch.load_file(cfg.saver.load_model, device='cpu')
            for key in list(state_dict):
                if "renderer.flame_model" in key:
                    print(f"pop:{key}, shape:{state_dict[key].shape}")
                    state_dict.pop(key)
                if key == "renderer.gs_net.out_layers.scaling.weight":
                    if state_dict[key].shape != model_state_dict[key].shape:
                        # Shape mismatch: skip the scaling head entirely.
                        state_dict.pop("renderer.gs_net.out_layers.scaling.weight")
                        state_dict.pop("renderer.gs_net.out_layers.scaling.bias")

            missing, unexpected = model.load_state_dict(state_dict, strict=False)
            print("missing:", set(missing))
            print("unexpected:", unexpected)

        if self.has_disc and cfg.saver.get("load_model_disc", None) is not None:
            safetensors.torch.load_model(
                self.accelerator.unwrap_model(self.model_disc),
                cfg.saver.load_model_disc,
                strict=cfg.saver.load_model_strict if hasattr(cfg.saver, "load_model_strict") else True,
            )
        logger.info("======== Model loaded ========")
        # Report success so load_ckpt_or_auto_resume_ can stop here.
        return True

    @control(synchronized=True)
    def load_ckpt_or_auto_resume_(self, cfg):
        # Resolution order: explicit checkpoint dir -> auto-resume from the
        # experiment checkpoint root -> plain model weights file.
        if hasattr(cfg.saver, "load_ckpt") and cfg.saver.load_ckpt:
            successful_resume = self.auto_resume_(cfg, ckpt_root=cfg.saver.load_ckpt)
            if successful_resume:
                return

        if cfg.saver.auto_resume:
            successful_resume = self.auto_resume_(cfg)
            if successful_resume:
                return

        if cfg.saver.load_model:
            successful_load = self.load_model_(cfg)
            if successful_load:
                return

        logger.debug("======== No checkpoint or model is loaded ========")

    def _save_checkpoint(self):
        ckpt_dir = os.path.join(
            self.cfg.saver.checkpoint_root,
            self.cfg.experiment.parent, self.cfg.experiment.child,
            f"{self.global_step:06d}",
        )
        self.accelerator.save_state(output_dir=ckpt_dir, safe_serialization=True)
        logger.info(f"======== Saved checkpoint at global step {self.global_step} ========")

        # Prune old checkpoints, keeping a coarser grid of steps the further
        # back they are (digit by digit in base `checkpoint_keep_level`).
        ckpt_dirs = os.listdir(os.path.dirname(ckpt_dir))
        ckpt_dirs.sort()
        max_ckpt = int(ckpt_dirs[-1])
        ckpt_base = int(self.cfg.saver.checkpoint_keep_level)
        ckpt_period = self.cfg.saver.checkpoint_global_steps
        logger.debug(f"Checkpoint base: {ckpt_base}")
        logger.debug(f"Checkpoint period: {ckpt_period}")
        cur_order = ckpt_base ** math.floor(math.log(max_ckpt // ckpt_period, ckpt_base))
        cur_idx = 0
        while cur_order > 0:
            cur_digit = max_ckpt // ckpt_period // cur_order % ckpt_base
            while cur_idx < len(ckpt_dirs) and int(ckpt_dirs[cur_idx]) // ckpt_period // cur_order % ckpt_base < cur_digit:
                if int(ckpt_dirs[cur_idx]) // ckpt_period % cur_order != 0:
                    shutil.rmtree(os.path.join(os.path.dirname(ckpt_dir), ckpt_dirs[cur_idx]))
                    logger.info(f"Removed checkpoint {ckpt_dirs[cur_idx]}")
                cur_idx += 1
            cur_order //= ckpt_base

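    # Worked example of the pruning rule (hypothetical numbers): with
    # checkpoint_keep_level=2 and checkpoint_global_steps=100, once the latest
    # checkpoint is step 700 (index 7 = 4 + 2 + 1 in base 2), the survivors are
    # steps 400, 600 and 700; older fine-grained checkpoints are deleted, so
    # checkpoint density decays geometrically with age.
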
    def save_checkpoint(self):
        if self.accelerator.state.deepspeed_plugin is not None:
            # DeepSpeed shards optimizer state across ranks, so every process
            # must participate in saving.
            logger.info("Saving checkpoint in DeepSpeed mode...")
            self._save_checkpoint()
        else:
            if self.accelerator.is_main_process:
                self._save_checkpoint()

    @control('on_main_process')
    def snapshot_cfg(self, cfg):
        # Persist the resolved config next to the tracker logs.
        save_path = os.path.join(self.trackers_logging_dir, "config.yaml")
        OmegaConf.save(cfg, save_path)

    @property
    def global_step_in_epoch(self):
        return self.global_step % self.N_global_steps_per_epoch

    @abstractmethod
    def _build_model(self):
        pass

    @abstractmethod
    def _build_optimizer(self):
        pass

    @abstractmethod
    def _build_scheduler(self):
        pass

    @abstractmethod
    def _build_dataloader(self):
        pass

    @abstractmethod
    def _build_loss_fn(self):
        pass

    @abstractmethod
    def train(self):
        pass

    @abstractmethod
    def evaluate(self):
        pass

    @staticmethod
    def _get_str_progress(epoch: int = None, step: int = None):
        if epoch is not None:
            log_type = 'epoch'
            log_progress = epoch
        elif step is not None:
            log_type = 'step'
            log_progress = step
        else:
            raise ValueError('Either epoch or step must be provided')
        return log_type, log_progress

    @control('on_main_process')
    def log_scalar_kwargs(self, epoch: int = None, step: int = None, split: str = None, **scalar_kwargs):
        log_type, log_progress = self._get_str_progress(epoch, step)
        split = f'/{split}' if split else ''
        for key, value in scalar_kwargs.items():
            self.accelerator.log({f'{key}{split}/{log_type}': value}, log_progress)

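    # Usage sketch (hypothetical metric values): logs 'loss/train/step' and
    # 'psnr/train/step' against the current global step.
    #
    #     self.log_scalar_kwargs(step=self.global_step, split='train',
    #                            loss=loss.item(), psnr=psnr.item())
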
    def log_images_each_process(self, values: dict, step: int | None = None, log_kwargs: dict | None = None):
        log_kwargs = log_kwargs or {}
        for tracker in self.accelerator.trackers:
            if hasattr(tracker, 'log_images'):
                tracker.log_images(values, step=step, **log_kwargs.get(tracker.name, {}))

        # Additionally dump the first image of each batch to disk. Values are
        # assumed to be (B, C, H, W) tensors in [0, 1]; cv2 expects BGR order.
        log_dir = self.trackers_logging_dir
        if log_kwargs.get("imwrite_image", True):
            for k, v in values.items():
                v = v[0].permute(1, 2, 0).detach().cpu().numpy()
                save_path = os.path.join(log_dir, f"{step:05d}_{k.replace('/', '_')}.jpg")
                cv2.imwrite(save_path, (v * 255).astype(np.uint8)[:, :, (2, 1, 0)])

    @control('on_main_process')
    def log_images(self, values: dict, step: int | None = None, log_kwargs: dict | None = None):
        self.log_images_each_process(values, step, log_kwargs)

    @control('on_main_process')
    def log_optimizer(self, epoch: int = None, step: int = None, attrs: list[str] = [], group_ids: list[int] = []):
        log_type, log_progress = self._get_str_progress(epoch, step)
        assert self.optimizer is not None, 'Optimizer is not initialized'
        if not attrs:
            logger.warning('No optimizer attributes are provided, nothing will be logged')
        if not group_ids:
            logger.warning('No optimizer group ids are provided, nothing will be logged')
        for attr in attrs:
            assert attr in ['lr', 'momentum', 'weight_decay'], f'Invalid optimizer attribute {attr}'
            for group_id in group_ids:
                self.accelerator.log({f'opt/{attr}/{group_id}': self.optimizer.param_groups[group_id][attr]}, log_progress)

    @control('on_main_process')
    def log_inital_info(self):
        assert self.model is not None, 'Model is not initialized'
        assert self.optimizer is not None, 'Optimizer is not initialized'
        assert self.scheduler is not None, 'Scheduler is not initialized'
        self.accelerator.log({'Config': "```\n" + OmegaConf.to_yaml(self.cfg) + "\n```"})
        self.accelerator.log({'Model': "```\n" + str(self.model) + "\n```"})
        self.accelerator.log({'Optimizer': "```\n" + str(self.optimizer) + "\n```"})
        self.accelerator.log({'Scheduler': "```\n" + str(self.scheduler) + "\n```"})

    def run(self):
        self.train()
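

# Minimal end-to-end sketch (hypothetical subclass name; the real runners
# implement the abstract hooks): the context-manager protocol initializes
# trackers and prepares state on entry, and ends tracking on exit.
#
#     class MyTrainer(Trainer):
#         def register_hooks(self): ...
#         def _build_model(self): ...
#         # ... remaining _build_* hooks, plus train() and evaluate() ...
#
#     if __name__ == "__main__":
#         with MyTrainer() as trainer:
#             trainer.run()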