# Copyright (c) 2025 Hanwen Jiang, Xuweiyi Chen. Adapted for WildRayZer from the RayZer project.

import torch
from transformers import (
    get_constant_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
    get_linear_schedule_with_warmup,
)
import torch.distributed as dist
import os
from rich import print
import traceback
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.nn.functional as F


def bilinear_resize(x, size=None, scale_factor=None, antialias=True):
    """Resize ``x`` with bilinear interpolation.

    Thin wrapper around ``torch.nn.functional.interpolate`` that fixes
    ``mode="bilinear"`` and ``align_corners=False``.

    Args:
        x: input tensor, forwarded unchanged to ``F.interpolate``.
        size: target output size, forwarded to ``F.interpolate``.
        scale_factor: spatial scale multiplier, forwarded to ``F.interpolate``.
        antialias: whether to apply anti-aliasing (default ``True``).

    Returns:
        The tensor returned by ``F.interpolate``.
    """
    return F.interpolate(
        x,
        size=size,
        scale_factor=scale_factor,
        antialias=antialias,
        mode="bilinear",
        align_corners=False,
    )


def print_rank0(*args, **kwargs):
    """Print on rank 0 only, or unconditionally when not running distributed.

    All arguments are forwarded to ``print`` (the ``rich`` version imported
    at the top of this module).
    """
    if dist.is_initialized():
        if dist.get_rank() == 0:
            print(*args, **kwargs)
    else:
        print(*args, **kwargs)


def format_number(num):
    """Format a number with a K / M / B suffix and two decimals.

    Values below 1000 are returned as ``str(num)`` without a suffix.
    """
    if num >= 1_000_000_000:
        return f"{num / 1_000_000_000:.2f}B"
    elif num >= 1_000_000:
        return f"{num / 1_000_000:.2f}M"
    elif num >= 1_000:
        return f"{num / 1_000:.2f}K"
    return str(num)


def create_optimizer(model, weight_decay, learning_rate, betas):
    """Build a fused AdamW optimizer with separate decay / no-decay groups.

    Trainable parameters that are 1-D, or that carry a truthy
    ``_no_weight_decay`` attribute, are placed in a group with
    ``weight_decay=0.0``; every other trainable parameter uses
    ``weight_decay``. Parameters with ``requires_grad=False`` are not
    handed to the optimizer.

    A summary (hyper-parameters, parameter counts, optimized / frozen module
    names) is printed on rank 0, or unconditionally when ``torch.distributed``
    is not initialized.

    Args:
        model: module whose ``named_parameters()`` are optimized.
        weight_decay: weight decay applied to the decay group.
        learning_rate: learning rate passed to AdamW.
        betas: AdamW betas.

    Returns:
        Tuple ``(optimizer, optimized_param_dict, all_param_dict)`` where the
        two dicts map parameter names to parameters (trainable only / all).
    """
    # start with all of the candidate parameters
    all_param_dict = {name: param for name, param in model.named_parameters()}
    # filter out those that do not require grad
    optimized_param_dict = {
        name: param for name, param in all_param_dict.items() if param.requires_grad
    }

    decay_params, nodecay_params = [], []
    for param in optimized_param_dict.values():
        if param.dim() == 1 or getattr(param, '_no_weight_decay', False):
            nodecay_params.append(param)
        else:
            decay_params.append(param)

    optim_groups = [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': nodecay_params, 'weight_decay': 0.0},
    ]
    # use fused AdamW optimizer by default.
    # NOTE(review): fused=True restricts which devices/dtypes the parameters
    # may use depending on the installed torch version -- confirm for CPU runs.
    optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, fused=True)

    # Print Model Information.
    # Use the same rule as print_rank0: calling dist.get_rank() without an
    # initialized process group would fail in single-process runs.
    if not dist.is_initialized() or dist.get_rank() == 0:
        def get_module_name(name):
            parts = name.split('.')
            if len(parts) > 2 and parts[0] == 'module':
                return parts[1] + '.' + parts[2]
            return parts[0]  # Fallback to first part if no 'module.' prefix

        print(f'Optimizer: AdamW, learning rate: {learning_rate}, weight decay: {weight_decay}, betas: {betas}')

        # Number of parameters
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in optimized_param_dict.values())

        optim_module_names = sorted({get_module_name(name) for name in optimized_param_dict})
        frozen_param_names = set(all_param_dict) - set(optimized_param_dict)
        frozen_module_names = sorted({get_module_name(name) for name in frozen_param_names})

        print(f'Total parameters: {format_number(total_params)}, Trainable parameters: {format_number(trainable_params)}')
        print(f'Optimized parameters: {optim_module_names}')
        print(f'Frozen parameters: {frozen_module_names}')

    return optimizer, optimized_param_dict, all_param_dict


def create_lr_scheduler(optimizer, param_update_steps, warm_up_steps, scheduler_type='cosine'):
    """Create a warm-up learning-rate scheduler from ``transformers``.

    Args:
        optimizer: optimizer to schedule.
        param_update_steps: total number of training steps; passed to the
            linear and cosine schedules, unused by the constant schedule.
        warm_up_steps: number of warm-up steps.
        scheduler_type: one of ``'linear'``, ``'cosine'``, ``'constant'``.

    Returns:
        The scheduler built by the matching ``get_*_schedule_with_warmup``.

    Raises:
        ValueError: if ``scheduler_type`` is not one of the supported names.
    """
    if scheduler_type == 'linear':
        scheduler = get_linear_schedule_with_warmup(optimizer, warm_up_steps, param_update_steps)
    elif scheduler_type == 'cosine':
        scheduler = get_cosine_schedule_with_warmup(optimizer, warm_up_steps, param_update_steps)
    elif scheduler_type == 'constant':
        scheduler = get_constant_schedule_with_warmup(optimizer, warm_up_steps)
    else:
        raise ValueError(f'Invalid scheduler type: {scheduler_type}')
    return scheduler


def find_checkpoints(load_path):
    """Return the checkpoint paths found at ``load_path``.

    If ``load_path`` is a directory, every entry ending in ``.pt`` is returned
    as a joined path, sorted by file name (plain string order, so numeric
    ordering relies on the names sorting that way, e.g. zero-padded steps).
    Otherwise ``load_path`` itself is returned as a one-element list when it
    ends in ``.pt``, and an empty list when it does not.
    """
    if os.path.isdir(load_path):
        ckpt_names = sorted(
            file_name for file_name in os.listdir(load_path) if file_name.endswith(".pt")
        )
        ckpt_paths = [os.path.join(load_path, ckpt_name) for ckpt_name in ckpt_names]
    else:
        if load_path.endswith(".pt"):
            ckpt_paths = [load_path]
        else:
            ckpt_paths = []
    return ckpt_paths


def auto_resume_job(
    load_path,
    model,
    optimizer,
    lr_scheduler,
    reset_training_state
):
    """
    Resume training from the latest checkpoint in the specified directory.
    Returns the optimizer, lr_scheduler, forward_pass_step and param_update_step.

    Loading is best-effort: if no checkpoint is found, or the checkpoint file
    cannot be read, the inputs are returned untouched with both step counters
    at 0. Model weights are loaded with ``strict=False``. If restoring the
    training state fails part-way, the traceback is printed and whatever was
    restored before the failure is kept.

    Args:
        load_path: If dir, load the last checkpoint in the directory. O.w., assume it's a ckpt and load it.
        model: model to be loaded
        optimizer: optimizer to be loaded
        lr_scheduler: lr scheduler to be loaded
        reset_training_state: whether to reset the training state
            (when true, only the model weights are loaded)

    Returns:
        optimizer, lr_scheduler, forward_pass_step, param_update_step
    """
    forward_pass_step = 0
    param_update_step = 0

    all_ckpt_paths = find_checkpoints(load_path)
    if len(all_ckpt_paths) == 0:
        print_rank0(f"No checkpoint found in {load_path}, we will start from scratch")
        return optimizer, lr_scheduler, forward_pass_step, param_update_step

    # The list is sorted by name, so the last entry is treated as the latest.
    ckpt_path = all_ckpt_paths[-1]
    try:
        # NOTE: torch.load unpickles the file; only load checkpoints from
        # trusted locations.
        checkpoint = torch.load(ckpt_path, map_location="cpu")
    except Exception:
        # Exception (not a bare except) so Ctrl-C / SystemExit still propagate.
        traceback.print_exc()
        print_rank0(f"Failed to load {ckpt_path}, we will start from scratch")
        return optimizer, lr_scheduler, forward_pass_step, param_update_step

    # Load model weights
    if isinstance(model, DDP):
        status = model.module.load_state_dict(checkpoint['model'], strict=False)
    else:
        status = model.load_state_dict(checkpoint['model'], strict=False)
    print_rank0(f"Loaded model from {os.path.abspath(ckpt_path)}, the status is {status}")

    # resume training state
    if not reset_training_state:
        try:
            optimizer.load_state_dict(checkpoint["optimizer"])
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
            forward_pass_step = checkpoint["fwdbwd_pass_step"]
            param_update_step = checkpoint["param_update_step"]
            print_rank0(f"Resumed optimizer and lr_scheduler from {ckpt_path}")
        except Exception:
            # Best-effort resume: report the failure and continue training.
            traceback.print_exc()
            print_rank0(f"Failed to load optimizer and lr_scheduler from {ckpt_path}")

    return optimizer, lr_scheduler, forward_pass_step, param_update_step