Spaces:
Runtime error
Runtime error
| # Copyright (c) 2023-2024, Zexin He | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # https://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import os | |
| import time | |
| import math | |
| import argparse | |
| import shutil | |
| import torch | |
| import safetensors | |
| from omegaconf import OmegaConf | |
| from abc import abstractmethod | |
| from contextlib import contextmanager | |
| from accelerate import Accelerator | |
| from accelerate.logging import get_logger | |
| from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed | |
| from openlrm.utils.logging import configure_logger | |
| from openlrm.utils.compile import configure_dynamo | |
| from openlrm.runners.abstract import Runner | |
| from collections import OrderedDict | |
| from huggingface_hub import hf_hub_download | |
| # def my_save_pre_hook(models, weights, output_dir): | |
| # keep = ["_lora", "synthesizer", "front_back_conv"] | |
| # for weight_dict in weights: | |
| # keys_to_keep = [key for key in weight_dict if any(keep_str in key for keep_str in keep)] | |
| # new_weight_dict = OrderedDict((key, weight_dict[key]) for key in keys_to_keep) | |
| # weight_dict.clear() | |
| # weight_dict.update(new_weight_dict) | |
| from collections import OrderedDict | |
| def my_save_pre_hook(models, weights, output_dir): | |
| assert len(models) == len(weights), "Models and weights must correspond one-to-one" | |
| filtered_weights_list = [] | |
| for model, model_weights in zip(models, weights): | |
| filtered_weights = OrderedDict() | |
| for name, param in model.named_parameters(): | |
| if param.requires_grad: | |
| if name in model_weights: | |
| filtered_weights[name] = model_weights[name] | |
| filtered_weights_list.append(filtered_weights) | |
| weights.clear() | |
| weights.extend(filtered_weights_list) | |
| logger = get_logger(__name__) | |
| def parse_configs(): | |
| # Define argparse arguments | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument('--config', type=str, default='./assets/config.yaml') | |
| args, unknown = parser.parse_known_args() | |
| # Load configuration file | |
| cfg = OmegaConf.load(args.config) | |
| # Override with command-line arguments | |
| cli_cfg = OmegaConf.from_cli(unknown) | |
| cfg = OmegaConf.merge(cfg, cli_cfg) | |
| return cfg | |
| class Trainer(Runner): | |
| def __init__(self): | |
| super().__init__() | |
| self.cfg = parse_configs() | |
| self.timestamp = time.strftime("%Y%m%d-%H%M%S") | |
| self.accelerator = Accelerator( | |
| mixed_precision=self.cfg.train.mixed_precision, | |
| gradient_accumulation_steps=self.cfg.train.accum_steps, | |
| log_with=tuple(self.cfg.logger.trackers), | |
| project_config=ProjectConfiguration( | |
| logging_dir=self.cfg.logger.tracker_root, | |
| ), | |
| use_seedable_sampler=True, | |
| kwargs_handlers=[ | |
| DistributedDataParallelKwargs( | |
| find_unused_parameters=self.cfg.train.find_unused_parameters, | |
| ), | |
| ], | |
| ) | |
| self.accelerator.register_save_state_pre_hook(my_save_pre_hook) # it is the save model hook. | |
| set_seed(self.cfg.experiment.seed, device_specific=True) | |
| with self.accelerator.main_process_first(): | |
| configure_logger( | |
| stream_level=self.cfg.logger.stream_level, | |
| log_level=self.cfg.logger.log_level, | |
| file_path=os.path.join( | |
| self.cfg.logger.log_root, | |
| self.cfg.experiment.parent, self.cfg.experiment.child, | |
| f"{self.timestamp}.log", | |
| ) if self.accelerator.is_main_process else None, | |
| ) | |
| logger.info(self.accelerator.state, main_process_only=False, in_order=True) | |
| configure_dynamo(dict(self.cfg.compile)) | |
| # attributes with defaults | |
| self.model : torch.nn.Module = None | |
| self.optimizer: torch.optim.Optimizer = None | |
| self.scheduler: torch.optim.lr_scheduler.LRScheduler = None | |
| self.train_loader: torch.utils.data.DataLoader = None | |
| self.val_loader: torch.utils.data.DataLoader = None | |
| self.N_max_global_steps: int = None | |
| self.N_global_steps_per_epoch: int = None | |
| self.global_step: int = 0 | |
| self.current_epoch: int = 0 | |
| def __enter__(self): | |
| self.accelerator.init_trackers( | |
| project_name=f"{self.cfg.experiment.parent}/{self.cfg.experiment.child}", | |
| ) | |
| self.prepare_everything() | |
| self.log_inital_info() | |
| return self | |
| def __exit__(self, exc_type, exc_val, exc_tb): | |
| self.accelerator.end_training() | |
| def control(option: str = None, synchronized: bool = False): | |
| def decorator(func): | |
| def wrapper(self, *args, **kwargs): | |
| if option is None or hasattr(self.accelerator, option): | |
| accelerated_func = getattr(self.accelerator, option)(func) if option is not None else func | |
| result = accelerated_func(self, *args, **kwargs) | |
| if synchronized: | |
| self.accelerator.wait_for_everyone() | |
| return result | |
| else: | |
| raise AttributeError(f"Accelerator has no attribute {option}") | |
| return wrapper | |
| return decorator | |
| def exec_in_order(self): | |
| for rank in range(self.accelerator.num_processes): | |
| try: | |
| if self.accelerator.process_index == rank: | |
| yield | |
| finally: | |
| self.accelerator.wait_for_everyone() | |
| def device(self): | |
| return self.accelerator.device | |
| def is_distributed(self) -> bool: | |
| return self.accelerator.num_processes > 1 | |
| def prepare_everything(self, is_dist_validation: bool = True): | |
| # prepare with accelerator | |
| if is_dist_validation: | |
| self.model, self.optimizer, self.train_loader, self.val_loader = \ | |
| self.accelerator.prepare( | |
| self.model, self.optimizer, self.train_loader, self.val_loader, | |
| ) | |
| else: | |
| self.model, self.optimizer, self.train_loader = \ | |
| self.accelerator.prepare( | |
| self.model, self.optimizer, self.train_loader, | |
| ) | |
| self.accelerator.register_for_checkpointing(self.scheduler) | |
| # prepare stats | |
| N_total_batch_size = self.cfg.train.batch_size * self.accelerator.num_processes * self.cfg.train.accum_steps | |
| self.N_global_steps_per_epoch = math.ceil(len(self.train_loader) / self.cfg.train.accum_steps) | |
| self.N_max_global_steps = self.N_global_steps_per_epoch * self.cfg.train.epochs | |
| if self.cfg.train.debug_global_steps is not None: | |
| logger.warning(f"Overriding max global steps from {self.N_max_global_steps} to {self.cfg.train.debug_global_steps}") | |
| self.N_max_global_steps = self.cfg.train.debug_global_steps | |
| logger.info(f"======== Statistics ========") | |
| logger.info(f"** N_max_global_steps: {self.N_max_global_steps}") | |
| logger.info(f"** N_total_batch_size: {N_total_batch_size}") | |
| logger.info(f"** N_epochs: {self.cfg.train.epochs}") | |
| logger.info(f"** N_global_steps_per_epoch: {self.N_global_steps_per_epoch}") | |
| logger.debug(f"** Prepared loader length: {len(self.train_loader)}") | |
| logger.info(f"** Distributed validation: {is_dist_validation}") | |
| logger.info(f"============================") | |
| logger.info(f"======== Trainable parameters ========") | |
| logger.info(f"** Total: {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}") | |
| for sub_name, sub_module in self.accelerator.unwrap_model(self.model).named_children(): | |
| logger.info(f"** {sub_name}: {sum(p.numel() for p in sub_module.parameters() if p.requires_grad)}") | |
| logger.info(f"=====================================") | |
| self.accelerator.wait_for_everyone() | |
| # load checkpoint or model | |
| self.load_ckpt_or_auto_resume_(self.cfg) | |
| # register hooks | |
| self.register_hooks() | |
| def register_hooks(self): | |
| pass | |
| def auto_resume_(self, cfg) -> bool: | |
| ckpt_root = os.path.join( | |
| cfg.saver.checkpoint_root, | |
| cfg.experiment.parent, cfg.experiment.child, | |
| ) | |
| if not os.path.exists(ckpt_root): | |
| return False | |
| ckpt_dirs = os.listdir(ckpt_root) | |
| if len(ckpt_dirs) == 0: | |
| return False | |
| ckpt_dirs.sort() | |
| latest_ckpt = ckpt_dirs[-1] | |
| latest_ckpt_dir = os.path.join(ckpt_root, latest_ckpt) | |
| logger.info(f"======== Auto-resume from {latest_ckpt_dir} ========") | |
| self.accelerator.load_state(latest_ckpt_dir, strict=cfg.saver.load_model_func_kwargs.strict) | |
| self.global_step = int(latest_ckpt) | |
| self.current_epoch = self.global_step // self.N_global_steps_per_epoch | |
| return True | |
| def load_model_(self, cfg): | |
| if cfg.saver.load_model.type == 'hugging_face': | |
| repo_id, file_name = os.path.dirname(cfg.saver.load_model.url), os.path.basename(cfg.saver.load_model.url) | |
| pretrain_model_path = hf_hub_download(repo_id=repo_id, filename=file_name) | |
| logger.info(f"======== Loading pretrain model from hugging face {repo_id, file_name} ========") | |
| safetensors.torch.load_model( | |
| self.accelerator.unwrap_model(self.model), | |
| pretrain_model_path, | |
| **cfg.saver.load_model_func_kwargs | |
| ) | |
| logger.info(f"======== Pretrain Model loaded ========") | |
| return True | |
| else: | |
| logger.info(f"======== Loading model from {cfg.saver.load_model} ========") | |
| safetensors.torch.load_model( | |
| self.accelerator.unwrap_model(self.model), | |
| cfg.saver.load_model, | |
| strict=True, | |
| ) | |
| logger.info(f"======== Model loaded ========") | |
| return True | |
| def load_ckpt_or_auto_resume_(self, cfg): | |
| # auto resume has higher priority, load model from path if auto resume is not available | |
| # cfg.saver.auto_resume and cfg.saver.load_model | |
| if cfg.saver.auto_resume: | |
| successful_resume = self.auto_resume_(cfg) | |
| if successful_resume: | |
| if cfg.saver.load_model: | |
| successful_load = self.load_model_(cfg) | |
| if successful_load: | |
| return | |
| return | |
| if cfg.saver.load_model: | |
| successful_load = self.load_model_(cfg) | |
| if successful_load: | |
| return | |
| logger.debug(f"======== No checkpoint or model is loaded ========") | |
| def save_checkpoint(self): | |
| ckpt_dir = os.path.join( | |
| self.cfg.saver.checkpoint_root, | |
| self.cfg.experiment.parent, self.cfg.experiment.child, | |
| f"{self.global_step:06d}", | |
| ) | |
| self.accelerator.save_state(output_dir=ckpt_dir, safe_serialization=True) | |
| logger.info(f"======== Saved checkpoint at global step {self.global_step} ========") | |
| # manage stratified checkpoints | |
| ckpt_dirs = os.listdir(os.path.dirname(ckpt_dir)) | |
| ckpt_dirs.sort() | |
| max_ckpt = int(ckpt_dirs[-1]) | |
| ckpt_base = int(self.cfg.saver.checkpoint_keep_level) | |
| ckpt_period = self.cfg.saver.checkpoint_global_steps | |
| logger.debug(f"Checkpoint base: {ckpt_base}") | |
| logger.debug(f"Checkpoint period: {ckpt_period}") | |
| cur_order = ckpt_base ** math.floor(math.log(max_ckpt // ckpt_period, ckpt_base)) | |
| cur_idx = 0 | |
| while cur_order > 0: | |
| cur_digit = max_ckpt // ckpt_period // cur_order % ckpt_base | |
| while cur_idx < len(ckpt_dirs) and int(ckpt_dirs[cur_idx]) // ckpt_period // cur_order % ckpt_base < cur_digit: | |
| if int(ckpt_dirs[cur_idx]) // ckpt_period % cur_order != 0: | |
| shutil.rmtree(os.path.join(os.path.dirname(ckpt_dir), ckpt_dirs[cur_idx])) | |
| logger.info(f"Removed checkpoint {ckpt_dirs[cur_idx]}") | |
| cur_idx += 1 | |
| cur_order //= ckpt_base | |
| def global_step_in_epoch(self): | |
| return self.global_step % self.N_global_steps_per_epoch | |
| def _build_model(self): | |
| pass | |
| def _build_optimizer(self): | |
| pass | |
| def _build_scheduler(self): | |
| pass | |
| def _build_dataloader(self): | |
| pass | |
| def _build_loss_fn(self): | |
| pass | |
| def train(self): | |
| pass | |
| def evaluate(self): | |
| pass | |
| def _get_str_progress(epoch: int = None, step: int = None): | |
| if epoch is not None: | |
| log_type = 'epoch' | |
| log_progress = epoch | |
| elif step is not None: | |
| log_type = 'step' | |
| log_progress = step | |
| else: | |
| raise ValueError('Either epoch or step must be provided') | |
| return log_type, log_progress | |
| def log_scalar_kwargs(self, epoch: int = None, step: int = None, split: str = None, **scalar_kwargs): | |
| log_type, log_progress = self._get_str_progress(epoch, step) | |
| split = f'/{split}' if split else '' | |
| for key, value in scalar_kwargs.items(): | |
| self.accelerator.log({f'{key}{split}/{log_type}': value}, log_progress) | |
| def log_images(self, values: dict, step: int | None = None, log_kwargs: dict | None = {}): | |
| for tracker in self.accelerator.trackers: | |
| if hasattr(tracker, 'log_images'): | |
| tracker.log_images(values, step=step, **log_kwargs.get(tracker.name, {})) | |
| def log_optimizer(self, epoch: int = None, step: int = None, attrs: list[str] = [], group_ids: list[int] = []): | |
| log_type, log_progress = self._get_str_progress(epoch, step) | |
| assert self.optimizer is not None, 'Optimizer is not initialized' | |
| if not attrs: | |
| logger.warning('No optimizer attributes are provided, nothing will be logged') | |
| if not group_ids: | |
| logger.warning('No optimizer group ids are provided, nothing will be logged') | |
| for attr in attrs: | |
| assert attr in ['lr', 'momentum', 'weight_decay'], f'Invalid optimizer attribute {attr}' | |
| for group_id in group_ids: | |
| self.accelerator.log({f'opt/{attr}/{group_id}': self.optimizer.param_groups[group_id][attr]}, log_progress) | |
| def log_inital_info(self): | |
| assert self.model is not None, 'Model is not initialized' | |
| assert self.optimizer is not None, 'Optimizer is not initialized' | |
| assert self.scheduler is not None, 'Scheduler is not initialized' | |
| self.accelerator.log({'Config': "```\n" + OmegaConf.to_yaml(self.cfg) + "\n```"}) | |
| self.accelerator.log({'Model': "```\n" + str(self.model) + "\n```"}) | |
| self.accelerator.log({'Optimizer': "```\n" + str(self.optimizer) + "\n```"}) | |
| self.accelerator.log({'Scheduler': "```\n" + str(self.scheduler) + "\n```"}) | |
| def run(self): | |
| self.train() | |