| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| The main entry point to run the PPO algorithm |
| """ |
|
|
| import os |
| import logging |
| import warnings |
| import ray |
| import torch |
| import torch.distributed |
| from omegaconf import DictConfig, open_dict |
| from transformers import AutoModelForCausalLM |
|
|
| from verl.single_controller.base import Worker |
| from verl.single_controller.base.decorator import register, Dispatch |
| import verl.utils.torch_functional as verl_F |
| from verl import DataProto |
| from verl.utils.model import compute_position_id_with_mask |
| from verl.utils.fs import copy_local_path_from_hdfs |
| from verl.utils.fsdp_utils import get_fsdp_wrap_policy, load_fsdp_grad, offload_fsdp_grad, init_fn, get_init_weight_context_manager, get_fsdp_wrap_policy_vla |
| from verl.utils.fsdp_utils import offload_fsdp_optimizer, offload_fsdp_param_and_grad, load_fsdp_optimizer, load_fsdp_param_and_grad |
| from verl.utils.import_utils import import_external_libs |
| from verl.utils.debug import log_gpu_memory_usage |
| import verl.utils.hdfs_io as hdfs_io |
| from verl.utils import hf_tokenizer |
| from ..trainer.ppo import core_algos |
| from verl.utils.py_functional import append_to_dict |
| from codetiming import Timer |
|
|
|
|
| from verl.utils.openvla_utils import update_auto_map , check_model_logic_mismatch |
| from peft import LoraConfig, PeftModel, get_peft_model, TaskType |
| import json |
|
|
|
|
# Module-level logger. Its level is taken from the VERL_PPO_LOGGING_LEVEL environment
# variable (default 'WARN'). Note that it is keyed by this file's path (``__file__``).
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN'))
|
|
def convert_to_regular_types(obj):
    """Recursively convert Hydra/OmegaConf containers to plain Python types.

    ``DictConfig`` and ``dict`` become ``dict``; ``ListConfig``, ``list`` and ``tuple`` become
    ``list``. Conversion is applied to every nested value, including elements of a ``ListConfig``.
    The original version converted list elements only shallowly, so a ``DictConfig`` nested in a
    list leaked through. This lets the result be passed to libraries (e.g. PEFT's ``LoraConfig``)
    that expect built-in containers. Any other object is returned unchanged.
    """
    from omegaconf import ListConfig, DictConfig
    if isinstance(obj, (DictConfig, dict)):
        return {k: convert_to_regular_types(v) for k, v in obj.items()}
    if isinstance(obj, (ListConfig, list, tuple)):
        return [convert_to_regular_types(x) for x in obj]
    return obj
|
|
|
|
class RobActorRolloutRefWorker(Worker):
    """
    FSDP worker for OpenVLA-style vision-language-action models.

    Depending on ``role`` it can be instantiated as a standalone actor, a standalone rollout,
    a standalone reference policy, or a hybrid engine combining them (see ``config.rollout``).
    Only the HuggingFace (``'hf'``) rollout backend is supported.
    """

    def __init__(self, config: DictConfig, role: str):
        super().__init__()
        self.config = config
        import torch.distributed
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend="nccl")

        # A single flat FSDP mesh over every rank in the default process group.
        world_size = torch.distributed.get_world_size()
        from torch.distributed.device_mesh import init_device_mesh
        self.device_mesh = init_device_mesh('cuda', mesh_shape=(world_size,), mesh_dim_names=['fsdp'])

        self._is_lora = self.config.model.get('lora_rank', 0) > 0
        self.role = role
        assert self.role in ['actor', 'rollout', 'ref', 'actor_rollout', 'actor_rollout_ref']

        self._is_actor = self.role in ['actor', 'actor_rollout', 'actor_rollout_ref']
        self._is_rollout = self.role in ['rollout', 'actor_rollout', 'actor_rollout_ref']
        self._is_ref = self.role in ['ref', 'actor_rollout_ref']

        # Offload flags come from the actor config when this worker is an actor; a pure
        # reference worker only reads param_offload from the ref config.
        self._is_offload_param = False
        self._is_offload_grad = False
        self._is_offload_optimizer = False
        if self._is_actor:
            self._is_offload_param = self.config.actor.fsdp_config.get('param_offload', False)
            self._is_offload_grad = self.config.actor.fsdp_config.get('grad_offload', False)
            self._is_offload_optimizer = self.config.actor.fsdp_config.get('optimizer_offload', False)
        elif self._is_ref:
            self._is_offload_param = self.config.ref.fsdp_config.get('param_offload', False)

        # Batch sizes in the config are global; convert them to per-rank sizes (mutates config).
        if self._is_actor:
            self.config.actor.ppo_mini_batch_size //= self.device_mesh.shape[0]
            self.config.actor.ppo_micro_batch_size //= self.device_mesh.shape[0]
        if self._is_rollout:
            self.config.rollout.log_prob_micro_batch_size //= self.device_mesh.shape[0]
        if self._is_ref:
            self.config.ref.log_prob_micro_batch_size //= self.device_mesh.shape[0]

    def _build_model_optimizer(self,
                               model_path,
                               fsdp_config,
                               optim_config,
                               override_model_config,
                               enable_gradient_checkpointing=False,
                               trust_remote_code=False):
        """Load an OpenVLA / OpenVLA-OFT checkpoint, optionally add LoRA, and wrap it in FSDP.

        Side effect: sets ``self.tokenizer``.

        Args:
            model_path: local or HDFS checkpoint path (copied locally via ``copy_local_path_from_hdfs``).
            fsdp_config: FSDP options (``model_dtype``, ``mixed_precision``, ``wrap_policy``).
            optim_config: optimizer options; only read when this worker is an actor.
            override_model_config: extra kwargs merged into the HF model config.
            enable_gradient_checkpointing: call ``gradient_checkpointing_enable()`` on the model.
            trust_remote_code: forwarded to the tokenizer and ``AutoConfig``. The model itself is
                always loaded with ``trust_remote_code=True``.

        Returns:
            ``(fsdp_module, optimizer, lr_scheduler, hf_config)``. The optimizer and scheduler
            are ``None`` unless this worker is an actor.

        Raises:
            ValueError: if ``config.model.vla`` is neither ``'openvla-oft'`` nor ``'openvla'``.
        """
        from verl.utils.model import print_model_size, update_model_config
        from verl.utils.torch_dtypes import PrecisionType
        from transformers import AutoConfig, AutoImageProcessor, AutoModelForVision2Seq, AutoProcessor
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision
        from torch import optim

        vla = self.config.model.vla
        if vla not in ('openvla-oft', 'openvla'):
            # Fail fast. Otherwise ``actor_module`` is never assigned and the failure surfaces
            # much later as an UnboundLocalError.
            raise ValueError(f"Unsupported config.model.vla={vla!r}; expected 'openvla-oft' or 'openvla'.")

        log_gpu_memory_usage('Before init from HF AutoModel', logger=logger)
        local_path = copy_local_path_from_hdfs(model_path)

        # Register the variant-specific Prismatic/OpenVLA classes under the "openvla" model type
        # with the HF Auto* factories.
        if vla == "openvla-oft":
            from verl.utils.vla_utils.openvla_oft.configuration_prismatic import OpenVLAConfig
            from verl.utils.vla_utils.openvla_oft.modeling_prismatic import OpenVLAForActionPrediction
            from verl.utils.vla_utils.openvla_oft.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor
        else:
            from verl.utils.vla_utils.openvla.configuration_prismatic import OpenVLAConfig
            from verl.utils.vla_utils.openvla.modeling_prismatic import OpenVLAForActionPrediction
            from verl.utils.vla_utils.openvla.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor
        AutoConfig.register("openvla", OpenVLAConfig)
        AutoImageProcessor.register(OpenVLAConfig, PrismaticImageProcessor)
        AutoProcessor.register(OpenVLAConfig, PrismaticProcessor)
        AutoModelForVision2Seq.register(OpenVLAConfig, OpenVLAForActionPrediction)
        # Only rank 0 runs these checkpoint-directory helpers (presumably they patch/validate the
        # files under local_path); the barrier keeps other ranks from reading it before then.
        if self.rank == 0:
            update_auto_map(local_path)
            check_model_logic_mismatch(local_path)
        torch.distributed.barrier()

        self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code, model=vla)

        torch_dtype = fsdp_config.get('model_dtype', None)
        if torch_dtype is None:
            # Actor weights default to fp32 so the optimizer works on fp32 parameters.
            torch_dtype = torch.float32 if self._is_actor else torch.bfloat16
        else:
            torch_dtype = PrecisionType.to_dtype(torch_dtype)

        # Build the HF config and apply token-id / user overrides.
        actor_model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)
        if self.config.model.use_remove_padding:
            from verl.models.registry import check_model_support_rmpad
            check_model_support_rmpad(actor_model_config.model_type)
        override_config_kwargs = {
            'bos_token_id': self.tokenizer.bos_token_id,
            'eos_token_id': self.tokenizer.eos_token_id,
            'pad_token_id': self.tokenizer.pad_token_id,
        }
        override_config_kwargs.update(override_model_config)
        update_model_config(actor_model_config, override_config_kwargs=override_config_kwargs)
        if self.rank == 0:
            print(f'Model config after override: {actor_model_config}')

        init_context = get_init_weight_context_manager(use_meta_tensor=not actor_model_config.tie_word_embeddings)

        with init_context(), warnings.catch_warnings():
            warnings.simplefilter("ignore")
            if vla == "openvla-oft":
                actor_module = AutoModelForVision2Seq.from_pretrained(
                    pretrained_model_name_or_path=local_path,
                    torch_dtype=torch_dtype,
                    config=actor_model_config,
                    trust_remote_code=True,
                )
                actor_module.vision_backbone.set_num_images_in_input(self.config.actor.num_images_in_input)

                # Action (un)normalisation statistics stored next to fine-tuned checkpoints.
                dataset_statistics_path = os.path.join(local_path, "dataset_statistics.json")
                if os.path.isfile(dataset_statistics_path):
                    with open(dataset_statistics_path, "r") as f:
                        actor_module.norm_stats = json.load(f)
                else:
                    print(
                        "WARNING: No local dataset_statistics.json file found for current checkpoint.\n"
                        "You can ignore this if you are loading the base VLA (i.e. not fine-tuned) checkpoint.\n"
                        "Otherwise, you may run into errors when trying to call `predict_action()` due to an absent `unnorm_key`."
                    )
            else:
                actor_module = AutoModelForVision2Seq.from_pretrained(
                    pretrained_model_name_or_path=local_path,
                    torch_dtype=torch_dtype,
                    attn_implementation="flash_attention_2",
                    config=actor_model_config,
                    trust_remote_code=True,
                )

            # Some parameters may not be created in torch_dtype; force a uniform dtype.
            actor_module.to(torch_dtype)

            if enable_gradient_checkpointing:
                actor_module.gradient_checkpointing_enable()

            # NOTE(review): this also runs when building the reference model. With
            # init_lora_weights="gaussian", PEFT zero-initialises the B matrices, so the adapter
            # should start as a no-op — confirm that is the intended ref behaviour.
            if self._is_lora:
                print("Applying LoRA to actor module")
                lora_config = {
                    'r': self.config.model.lora_rank,
                    'lora_alpha': self.config.model.lora_alpha,
                    "lora_dropout": 0,
                    'target_modules': convert_to_regular_types(self.config.model.target_modules),
                    'init_lora_weights': "gaussian",
                }
                actor_module = get_peft_model(actor_module, LoraConfig(**lora_config))
                actor_module.print_trainable_parameters()
        torch.distributed.barrier()

        if self.rank == 0:
            print_model_size(actor_module)

        log_gpu_memory_usage('After init from HF AutoModel', logger=logger)

        # FSDP mixed precision; defaults: bf16 params, fp32 reduction and buffers.
        mixed_precision_config = fsdp_config.get('mixed_precision', None)
        if mixed_precision_config is not None:
            param_dtype = PrecisionType.to_dtype(mixed_precision_config.get('param_dtype', 'bf16'))
            reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get('reduce_dtype', 'fp32'))
            buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get('buffer_dtype', 'fp32'))
        else:
            param_dtype = torch.bfloat16
            reduce_dtype = torch.float32
            buffer_dtype = torch.float32

        mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)

        # NOTE(review): ``_is_ref`` is a per-worker flag. In the 'actor_rollout_ref' role this also
        # disables mixed precision for the *actor* module, not only the reference one. Kept as-is
        # to preserve current numerics; confirm whether that is intended.
        if self._is_ref:
            mixed_precision = None

        auto_wrap_policy = get_fsdp_wrap_policy_vla(module=actor_module,
                                                    config=fsdp_config.get('wrap_policy', None),
                                                    is_lora=self._is_lora)
        print(f'wrap_policy: {auto_wrap_policy}')

        # Without an auto-wrap policy use SHARD_GRAD_OP, otherwise FULL_SHARD.
        if auto_wrap_policy is None:
            sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
        else:
            sharding_strategy = ShardingStrategy.FULL_SHARD

        actor_module_fsdp = FSDP(
            actor_module,
            param_init_fn=init_fn,
            use_orig_params=False,
            auto_wrap_policy=auto_wrap_policy,
            device_id=torch.cuda.current_device(),
            sharding_strategy=sharding_strategy,
            mixed_precision=mixed_precision,
            sync_module_states=True,
            device_mesh=self.device_mesh)

        log_gpu_memory_usage('After Actor FSDP init', logger=logger)

        if self._is_actor:
            from verl.utils.torch_functional import get_constant_schedule_with_warmup
            actor_optimizer = optim.AdamW(actor_module_fsdp.parameters(),
                                          lr=optim_config.lr,
                                          betas=optim_config.get('betas', (0.9, 0.999)),
                                          weight_decay=optim_config.get('weight_decay', 1e-2))

            total_steps = optim_config.get('total_training_steps', 0)
            num_warmup_steps_ratio = optim_config.get('lr_warmup_steps_ratio', 0.)
            num_warmup_steps = int(num_warmup_steps_ratio * total_steps)

            print(f'Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}')

            actor_lr_scheduler = get_constant_schedule_with_warmup(optimizer=actor_optimizer,
                                                                   num_warmup_steps=num_warmup_steps)
        else:
            actor_optimizer = None
            actor_lr_scheduler = None

        log_gpu_memory_usage('After actor optimizer init', logger=logger)

        return actor_module_fsdp, actor_optimizer, actor_lr_scheduler, actor_model_config

    def _build_rollout(self):
        """Create the HF rollout engine and its (no-op) sharding manager.

        Raises:
            ValueError: for any ``config.rollout.name`` other than ``'hf'`` (including ``'vllm'``).
        """
        rollout_name = self.config.rollout.name
        if rollout_name != 'hf':
            raise ValueError(f"Unsupported rollout backend {rollout_name!r} for VLA models; only 'hf' is supported.")
        from verl.workers.rollout import RobHFRollout
        from verl.workers.hybrid_engine import BaseShardingManager
        rollout = RobHFRollout(module=self.actor_module_fsdp, config=self.config.rollout)
        sharding_manager = BaseShardingManager()
        return rollout, sharding_manager

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def init_model(self):
        """Build the actor/rollout and/or reference modules required by this worker's role."""
        from verl.workers.actor import RobDataParallelPPOActor
        import_external_libs(self.config.model.get('external_lib', None))

        from omegaconf import OmegaConf
        override_model_config = OmegaConf.to_container(self.config.model.get('override_config', OmegaConf.create()))

        if self._is_actor or self._is_rollout:
            # A rollout-only worker builds the same FSDP module but without an optimizer.
            if self._is_actor:
                optim_config = self.config.actor.optim
                fsdp_config = self.config.actor.fsdp_config
            else:
                optim_config = None
                fsdp_config = OmegaConf.create()
            self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler, self.actor_model_config = self._build_model_optimizer(
                model_path=self.config.model.path,
                fsdp_config=fsdp_config,
                optim_config=optim_config,
                override_model_config=override_model_config,
                enable_gradient_checkpointing=self.config.model.get('enable_gradient_checkpointing', False),
                trust_remote_code=True)

            # The module wrapped by FSDP (a PeftModel when LoRA is enabled).
            self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module

            if self._is_offload_param:
                # Only gradients are offloaded at init; parameters stay on GPU.
                offload_fsdp_grad(module=self.actor_module_fsdp)
                log_gpu_memory_usage('After offload actor grad during init', logger=logger)
            if self._is_offload_optimizer:
                offload_fsdp_optimizer(optimizer=self.actor_optimizer)
                log_gpu_memory_usage('After offload actor optimizer during init', logger=logger)

        if self._is_actor:
            OmegaConf.set_struct(self.config.actor, True)
            self.actor = RobDataParallelPPOActor(config=self.config.actor,
                                                 actor_module=self.actor_module_fsdp,
                                                 actor_optimizer=self.actor_optimizer)

        if self._is_rollout:
            self.rollout, self.sharding_manager = self._build_rollout()

        if self._is_ref:
            self.ref_module_fsdp = self._build_model_optimizer(model_path=self.config.model.path,
                                                               fsdp_config=self.config.ref.fsdp_config,
                                                               optim_config=None,
                                                               override_model_config=override_model_config,
                                                               trust_remote_code=True)[0]
            if self._is_offload_param:
                offload_fsdp_param_and_grad(module=self.ref_module_fsdp, offload_grad=self._is_offload_grad)

            OmegaConf.set_struct(self.config.ref, True)
            self.ref_policy = RobDataParallelPPOActor(config=self.config.ref, actor_module=self.ref_module_fsdp)

        torch.cuda.synchronize()
        torch.distributed.barrier()
        torch.cuda.empty_cache()

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def update_actor(self, data: DataProto):
        """Run one policy update on this rank's shard of ``data``; return the metrics on CPU."""
        # NOTE(review): unlike ActorRolloutRefWorker.update_actor, ``data`` is not moved to CUDA
        # here; presumably the Rob actor moves micro-batches itself — confirm.
        assert self._is_actor
        if self._is_offload_param:
            load_fsdp_param_and_grad(module=self.actor_module_fsdp,
                                     device_id=torch.cuda.current_device(),
                                     load_grad=self._is_offload_grad)
        if self._is_offload_optimizer:
            load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=torch.cuda.current_device())

        log_gpu_memory_usage('Before update policy', logger=logger)

        metrics = self.actor.update_policy(data=data)

        self.actor_lr_scheduler.step()
        lr = self.actor_lr_scheduler.get_last_lr()[0]
        metrics['actor/lr(1e-4)'] = lr * 1e4

        log_gpu_memory_usage('After update policy', logger=logger)

        output = DataProto(meta_info={'metrics': metrics})
        output = output.to('cpu')

        if self._is_offload_param:
            offload_fsdp_param_and_grad(module=self.actor_module_fsdp, offload_grad=self._is_offload_grad)
        if self._is_offload_optimizer:
            offload_fsdp_optimizer(optimizer=self.actor_optimizer)
        torch.cuda.synchronize()
        torch.distributed.barrier()
        torch.cuda.empty_cache()
        return output

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def compute_entropy(self, data: DataProto):
        """Compute entropy metrics for ``data`` with the actor; return them on CPU."""
        data = data.to('cuda')

        assert self._is_actor
        if self._is_offload_param:
            load_fsdp_param_and_grad(module=self.actor_module_fsdp,
                                     device_id=torch.cuda.current_device(),
                                     load_grad=self._is_offload_grad)

        data.batch = data.batch.cuda()

        log_gpu_memory_usage('Before compute entropy', logger=logger)

        # NOTE(review): keyword is spelled ``bacth_data``; it must match the actor's signature,
        # so it is left as-is here.
        metrics = self.actor.compute_entropy(bacth_data=data)

        log_gpu_memory_usage('After compute entropy', logger=logger)

        output = DataProto(meta_info={'metrics': metrics})
        output = output.to('cpu')

        if self._is_offload_param:
            offload_fsdp_param_and_grad(module=self.actor_module_fsdp, offload_grad=self._is_offload_grad)
        if self._is_offload_optimizer:
            offload_fsdp_optimizer(optimizer=self.actor_optimizer)
        torch.cuda.synchronize()
        torch.distributed.barrier()
        torch.cuda.empty_cache()
        return output

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def generate_sequences(self, prompts):
        """Roll out ``prompts``; if this worker is also the actor, attach ``old_log_probs``.

        ``prompts.meta_info['recompute_log_prob']`` (default True) controls the log-prob pass.
        """
        prompts = prompts.to('cuda')
        recompute_log_prob = prompts.meta_info.get('recompute_log_prob', True)

        assert self._is_rollout
        if self._is_offload_param:
            load_fsdp_param_and_grad(module=self.actor_module_fsdp,
                                     device_id=torch.cuda.current_device(),
                                     load_grad=self._is_offload_grad)

        prompts.batch = prompts.batch.cuda()
        meta_info = {'eos_token_id': self.tokenizer.eos_token_id, 'pad_token_id': self.tokenizer.pad_token_id}
        prompts.meta_info.update(meta_info)

        with self.sharding_manager:
            log_gpu_memory_usage('After entering sharding manager', logger=logger)
            prompts = self.sharding_manager.preprocess_data(prompts)
            output = self.rollout.generate_sequences(prompts=prompts)
            log_gpu_memory_usage('After rollout generation', logger=logger)

            output = self.sharding_manager.postprocess_data(output)
            torch.cuda.synchronize()

        if self._is_actor and recompute_log_prob:
            output.meta_info['micro_batch_size'] = self.config.rollout.log_prob_micro_batch_size
            output.meta_info['temperature'] = self.config.rollout.temperature
            output.meta_info['use_dynamic_bsz'] = self.config.rollout.log_prob_use_dynamic_bsz
            output.meta_info['max_token_len'] = self.config.rollout.log_prob_max_token_len_per_gpu
            output.meta_info['pad_token_id'] = self.tokenizer.pad_token_id
            old_log_probs = self.actor.compute_log_prob(data=output)
            output.batch['old_log_probs'] = old_log_probs

        output = output.to('cpu')

        if self._is_offload_param:
            offload_fsdp_param_and_grad(module=self.actor_module_fsdp, offload_grad=self._is_offload_grad)
        torch.cuda.synchronize()
        torch.distributed.barrier()
        torch.cuda.empty_cache()
        log_gpu_memory_usage('After recompute log prob', logger=logger)
        return output

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def compute_ref_log_prob(self, data: DataProto):
        """Return ``{'ref_log_prob': ...}`` computed by the reference policy, on CPU."""
        assert self._is_ref

        data = data.to('cuda')

        if self._is_offload_param:
            load_fsdp_param_and_grad(module=self.ref_module_fsdp,
                                     device_id=torch.cuda.current_device(),
                                     load_grad=self._is_offload_grad)

        micro_batch_size = self.config.ref.log_prob_micro_batch_size
        data.meta_info['micro_batch_size'] = micro_batch_size
        data.meta_info['temperature'] = self.config.rollout.temperature
        data.meta_info['max_token_len'] = self.config.ref.log_prob_max_token_len_per_gpu
        data.meta_info['use_dynamic_bsz'] = self.config.ref.log_prob_use_dynamic_bsz
        data.meta_info['pad_token_id'] = self.tokenizer.pad_token_id
        output = self.ref_policy.compute_log_prob(data=data)
        output = DataProto.from_dict(tensors={'ref_log_prob': output})

        output = output.to('cpu')

        if self._is_offload_param:
            offload_fsdp_param_and_grad(module=self.ref_module_fsdp, offload_grad=self._is_offload_grad)
        torch.cuda.synchronize()
        torch.distributed.barrier()
        torch.cuda.empty_cache()
        return output

    def _save_lora_checkpoint(self, local_path, hdfs_path=None):
        """Save the LoRA adapter and a merged full model (helper for ``save_checkpoint``).

        Must be called on every rank: gathering the adapter uses the collective
        ``FSDP.summon_full_params``. Only rank 0 writes files, reloads the base model, merges
        the adapter, and optionally uploads to ``hdfs_path``.
        """
        import collections
        import torch.distributed as dist
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        from peft import PeftModel
        from transformers import AutoModelForVision2Seq

        is_rank0 = dist.get_rank() == 0
        lora_save_path = os.path.join(local_path, "lora_adapter")
        if is_rank0:
            os.makedirs(local_path, exist_ok=True)

        # 1) Gather full adapter weights (collective) and write them from rank 0.
        if isinstance(self.actor_module_fsdp, FSDP):
            with FSDP.summon_full_params(self.actor_module_fsdp, writeback=False, offload_to_cpu=True):
                if is_rank0:
                    lora_params = collections.OrderedDict()
                    model = self.actor_module_fsdp._fsdp_wrapped_module.base_model.model
                    for name, param in model.named_parameters():
                        if ".lora_" in name:
                            # Drop FSDP wrapper segments and restore PEFT's top-level prefix.
                            name = "base_model.model." + name.replace("._fsdp_wrapped_module.", ".")
                            lora_params[name] = param
                    self.actor_module_fsdp.save_pretrained(
                        lora_save_path,
                        state_dict=lora_params,
                        safe_serialization=True
                    )
        elif is_rank0:
            self.actor_module.save_pretrained(lora_save_path, safe_serialization=True)

        dist.barrier()
        if is_rank0:
            print(f"[rank-{self.rank}]: Saved LoRA adapter to: {lora_save_path}")

            # 2) Merge on rank 0 only. This avoids one full CPU copy of the model per rank and
            #    does not require other nodes to see rank 0's local lora_save_path. The base
            #    model comes from the same local copy that training loaded.
            base_vla = AutoModelForVision2Seq.from_pretrained(
                copy_local_path_from_hdfs(self.config.model.path), torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True, trust_remote_code=True, device_map="cpu"
            )
            merged_vla = PeftModel.from_pretrained(base_vla, lora_save_path)
            merged_vla = merged_vla.merge_and_unload()
            merged_vla.save_pretrained(local_path)
            print(f"Saved merged model at: {local_path}")

            if hdfs_path is not None:
                print(f'Uploading actor checkpoint to {hdfs_path}')
                hdfs_io.makedirs(hdfs_path, exist_ok=True)
                hdfs_io.copy(src=local_path, dst=hdfs_path)

        dist.barrier()

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def save_checkpoint(self, local_path, hdfs_path=None):
        """Save the actor to ``local_path`` and optionally upload it to ``hdfs_path``.

        With LoRA, the adapter is saved to ``<local_path>/lora_adapter`` and a merged model to
        ``local_path``. Otherwise the full FSDP state dict is gathered on rank 0 and saved
        together with the tokenizer. Must be called on every rank (FSDP collectives + barriers).
        """
        assert self._is_actor

        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType, FullStateDictConfig
        from peft import PeftModel

        if self._is_offload_param:
            load_fsdp_param_and_grad(module=self.actor_module_fsdp,
                                     device_id=torch.cuda.current_device(),
                                     load_grad=self._is_offload_grad)

        if self._is_lora and isinstance(self.actor_module, PeftModel):
            self._save_lora_checkpoint(local_path, hdfs_path)
        else:
            cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
            with FSDP.state_dict_type(self.actor.actor_module, StateDictType.FULL_STATE_DICT, cfg):
                state_dict = self.actor.actor_module.state_dict()
            if self.rank == 0:
                print(f'Saving actor checkpoint to {local_path}')
                os.makedirs(local_path, exist_ok=True)
                self.actor_module.save_pretrained(local_path, state_dict=state_dict)
                self.tokenizer.save_pretrained(local_path)
                if hdfs_path is not None:
                    print(f'Uploading actor checkpoint to {hdfs_path}')
                    hdfs_io.makedirs(hdfs_path, exist_ok=True)
                    hdfs_io.copy(src=local_path, dst=hdfs_path)

        torch.distributed.barrier()
        if self._is_offload_param:
            offload_fsdp_param_and_grad(module=self.actor_module_fsdp, offload_grad=self._is_offload_grad)
|
|
|
|
| class ActorRolloutRefWorker(Worker): |
| """ |
| This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy |
| or a hybrid engine based on the config.rollout |
| """ |
|
|
| def __init__(self, config: DictConfig, role: str): |
| super().__init__() |
| self.config = config |
| import torch.distributed |
| if not torch.distributed.is_initialized(): |
| torch.distributed.init_process_group(backend="nccl") |
|
|
| |
| world_size = torch.distributed.get_world_size() |
| from torch.distributed.device_mesh import init_device_mesh |
| |
| self.device_mesh = init_device_mesh('cuda', mesh_shape=(world_size,), mesh_dim_names=['fsdp']) |
|
|
| self.role = role |
| assert self.role in ['actor', 'rollout', 'ref', 'actor_rollout', 'actor_rollout_ref'] |
|
|
| self._is_actor = self.role in ['actor', 'actor_rollout', 'actor_rollout_ref'] |
| self._is_rollout = self.role in ['rollout', 'actor_rollout', 'actor_rollout_ref'] |
| self._is_ref = self.role in ['ref', 'actor_rollout_ref'] |
|
|
| self._is_offload_param = False |
| self._is_offload_grad = False |
| self._is_offload_optimizer = False |
| if self._is_actor: |
| self._is_offload_param = self.config.actor.fsdp_config.get('param_offload', False) |
| self._is_offload_grad = self.config.actor.fsdp_config.get('grad_offload', False) |
| self._is_offload_optimizer = self.config.actor.fsdp_config.get('optimizer_offload', False) |
| elif self._is_ref: |
| |
| self._is_offload_param = self.config.ref.fsdp_config.get('param_offload', False) |
|
|
| |
| if self._is_actor: |
| self.config.actor.ppo_mini_batch_size //= self.device_mesh.shape[0] |
| self.config.actor.ppo_micro_batch_size //= self.device_mesh.shape[0] |
| if self._is_rollout: |
| self.config.rollout.log_prob_micro_batch_size //= self.device_mesh.shape[0] |
| if self._is_ref: |
| self.config.ref.log_prob_micro_batch_size //= self.device_mesh.shape[0] |
|
|
    def _build_model_optimizer(self,
                               model_path,
                               fsdp_config,
                               optim_config,
                               override_model_config,
                               enable_gradient_checkpointing=False,
                               trust_remote_code=False):
        """Load a causal LM through Liger kernels and wrap it in FSDP (plus optimizer for actors).

        Side effect: sets ``self.tokenizer``.

        Args:
            model_path: local or HDFS checkpoint path (copied locally via ``copy_local_path_from_hdfs``).
            fsdp_config: FSDP options (``model_dtype``, ``mixed_precision``, ``wrap_policy``).
            optim_config: optimizer options; only read when this worker is an actor.
            override_model_config: extra kwargs merged into the HF model config.
            enable_gradient_checkpointing: call ``gradient_checkpointing_enable()`` on the model.
            trust_remote_code: forwarded to the tokenizer, ``AutoConfig`` and the model loader.

        Returns:
            ``(fsdp_module, optimizer, lr_scheduler, hf_config)``; optimizer and scheduler are
            ``None`` unless this worker is an actor.
        """
        from verl.utils.model import print_model_size, update_model_config
        from verl.utils.torch_dtypes import PrecisionType
        from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision, \
            CPUOffload
        from torch import optim

        log_gpu_memory_usage('Before init from HF AutoModel', logger=logger)
        local_path = copy_local_path_from_hdfs(model_path)

        self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)

        # Actor weights default to fp32 (so the optimizer works on fp32 parameters); others bf16.
        torch_dtype = fsdp_config.get('model_dtype', None)
        if torch_dtype is None:
            torch_dtype = torch.float32 if self._is_actor else torch.bfloat16
        else:
            torch_dtype = PrecisionType.to_dtype(torch_dtype)

        # Build the HF config and apply token-id / user overrides.
        actor_model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)
        if self.config.model.use_remove_padding:
            from verl.models.registry import check_model_support_rmpad
            check_model_support_rmpad(actor_model_config.model_type)
        override_config_kwargs = {
            'bos_token_id': self.tokenizer.bos_token_id,
            'eos_token_id': self.tokenizer.eos_token_id,
            'pad_token_id': self.tokenizer.pad_token_id,
        }
        override_config_kwargs.update(override_model_config)
        update_model_config(actor_model_config, override_config_kwargs=override_config_kwargs)
        if self.rank == 0:
            print(f'Model config after override: {actor_model_config}')

        # Meta-tensor init is skipped for models with tied word embeddings.
        init_context = get_init_weight_context_manager(use_meta_tensor=not actor_model_config.tie_word_embeddings)

        with init_context(), warnings.catch_warnings():
            warnings.simplefilter("ignore")
            from liger_kernel.transformers import AutoLigerKernelForCausalLM
            actor_module = AutoLigerKernelForCausalLM.from_pretrained(pretrained_model_name_or_path=local_path,
                                                                      torch_dtype=torch_dtype,
                                                                      config=actor_model_config,
                                                                      attn_implementation='flash_attention_2',
                                                                      trust_remote_code=trust_remote_code)
            # Some parameters may not be created in torch_dtype; force a uniform dtype.
            actor_module.to(torch_dtype)

            if enable_gradient_checkpointing:
                actor_module.gradient_checkpointing_enable()
        torch.distributed.barrier()

        if self.rank == 0:
            print_model_size(actor_module)

        log_gpu_memory_usage('After init from HF AutoModel', logger=logger)

        # FSDP mixed precision; defaults: bf16 params, fp32 reduction and buffers.
        mixed_precision_config = fsdp_config.get('mixed_precision', None)
        if mixed_precision_config is not None:
            param_dtype = PrecisionType.to_dtype(mixed_precision_config.get('param_dtype', 'bf16'))
            reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get('reduce_dtype', 'fp32'))
            buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get('buffer_dtype', 'fp32'))
        else:
            param_dtype = torch.bfloat16
            reduce_dtype = torch.float32
            buffer_dtype = torch.float32

        mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)

        # NOTE(review): ``_is_ref`` is a per-worker flag. In the 'actor_rollout_ref' role this also
        # disables mixed precision for the *actor* module, not only the reference one — confirm
        # whether that is intended.
        if self._is_ref:
            mixed_precision = None

        auto_wrap_policy = get_fsdp_wrap_policy(module=actor_module, config=fsdp_config.get('wrap_policy', None))

        # With the HF rollout backend, auto-wrapping is disabled entirely.
        if self._is_rollout and self.config.rollout.name == 'hf':
            auto_wrap_policy = None

        print(f'wrap_policy: {auto_wrap_policy}')

        # Without an auto-wrap policy use SHARD_GRAD_OP, otherwise FULL_SHARD.
        if auto_wrap_policy is None:
            sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
        else:
            sharding_strategy = ShardingStrategy.FULL_SHARD

        actor_module_fsdp = FSDP(
            actor_module,
            param_init_fn=init_fn,
            use_orig_params=False,
            auto_wrap_policy=auto_wrap_policy,
            device_id=torch.cuda.current_device(),
            sharding_strategy=sharding_strategy,
            mixed_precision=mixed_precision,
            sync_module_states=True,
            device_mesh=self.device_mesh)

        log_gpu_memory_usage('After Actor FSDP init', logger=logger)

        # Optimizer and constant-with-warmup LR schedule, actors only.
        if self._is_actor:
            from verl.utils.torch_functional import get_constant_schedule_with_warmup
            actor_optimizer = optim.AdamW(actor_module_fsdp.parameters(),
                                          lr=optim_config.lr,
                                          betas=optim_config.get('betas', (0.9, 0.999)),
                                          weight_decay=optim_config.get('weight_decay', 1e-2))

            total_steps = optim_config.get('total_training_steps', 0)
            num_warmup_steps_ratio = optim_config.get('lr_warmup_steps_ratio', 0.)
            num_warmup_steps = int(num_warmup_steps_ratio * total_steps)

            print(f'Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}')

            actor_lr_scheduler = get_constant_schedule_with_warmup(optimizer=actor_optimizer,
                                                                   num_warmup_steps=num_warmup_steps)
        else:
            actor_optimizer = None
            actor_lr_scheduler = None

        log_gpu_memory_usage('After actor optimizer init', logger=logger)

        return actor_module_fsdp, actor_optimizer, actor_lr_scheduler, actor_model_config
|
|
| def _build_rollout(self): |
| if self.config.rollout.name == 'hf': |
| from verl.workers.rollout import HFRollout |
| from verl.workers.hybrid_engine import BaseShardingManager |
| rollout = HFRollout(module=self.actor_module_fsdp, config=self.config.rollout) |
| sharding_manager = BaseShardingManager() |
| |
| elif self.config.rollout.name == 'vllm': |
| from verl.workers.rollout.vllm_rollout import vLLMRollout |
| from verl.workers.hybrid_engine import FSDPVLLMShardingManager |
| log_gpu_memory_usage('Before building vllm rollout', logger=None) |
| rollout = vLLMRollout(actor_module=self.actor_module_fsdp, |
| config=self.config.rollout, |
| tokenizer=self.tokenizer, |
| model_hf_config=self.actor_model_config) |
| log_gpu_memory_usage('After building vllm rollout', logger=None) |
| if torch.distributed.get_world_size() == 1: |
| self.config.rollout.load_format = 'dummy_hf' |
| sharding_manager = FSDPVLLMShardingManager(module=self.actor_module_fsdp, |
| inference_engine=rollout.inference_engine, |
| model_config=self.actor_model_config, |
| full_params='hf' in self.config.rollout.load_format) |
| log_gpu_memory_usage('After building sharding manager', logger=None) |
|
|
| return rollout, sharding_manager |
|
|
    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def init_model(self):
        """Build the actor/rollout and/or reference modules required by this worker's role.

        Sets ``actor_module_fsdp``, ``actor_optimizer``, ``actor_lr_scheduler``,
        ``actor_model_config``, ``actor_module`` (actor/rollout), ``actor`` (actor), ``rollout``
        and ``sharding_manager`` (rollout), and ``ref_module_fsdp`` and ``ref_policy`` (ref).
        Each ``_build_model_optimizer`` call also reassigns ``self.tokenizer``.
        """
        from verl.workers.actor import DataParallelPPOActor
        # Import user-specified external libraries before any model is built.
        import_external_libs(self.config.model.get('external_lib', None))

        from omegaconf import OmegaConf
        override_model_config = OmegaConf.to_container(self.config.model.get('override_config', OmegaConf.create()))

        if self._is_actor or self._is_rollout:
            # A rollout-only worker builds the same FSDP module with an empty FSDP config and no optimizer.
            if self._is_actor:
                optim_config = self.config.actor.optim
                fsdp_config = self.config.actor.fsdp_config
            else:
                optim_config = None
                fsdp_config = OmegaConf.create()
            self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler, self.actor_model_config = self._build_model_optimizer(
                model_path=self.config.model.path,
                fsdp_config=fsdp_config,
                optim_config=optim_config,
                override_model_config=override_model_config,
                enable_gradient_checkpointing=self.config.model.get('enable_gradient_checkpointing', False),
                trust_remote_code=self.config.model.get('trust_remote_code', False))

            # The module wrapped by FSDP.
            self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module

            if self._is_offload_param:
                # Only gradients are offloaded at init; parameters stay on GPU.
                offload_fsdp_grad(module=self.actor_module_fsdp)
                log_gpu_memory_usage('After offload actor grad during init', logger=logger)
            if self._is_offload_optimizer:
                offload_fsdp_optimizer(optimizer=self.actor_optimizer)
                log_gpu_memory_usage('After offload actor optimizer during init', logger=logger)

        if self._is_actor:
            # Freeze the actor config structure before handing it to the PPO actor.
            OmegaConf.set_struct(self.config.actor, True)
            self.actor = DataParallelPPOActor(config=self.config.actor,
                                              actor_module=self.actor_module_fsdp,
                                              actor_optimizer=self.actor_optimizer)

        if self._is_rollout:
            self.rollout, self.sharding_manager = self._build_rollout()

        if self._is_ref:
            # Only the FSDP module is kept; the reference policy has no optimizer.
            self.ref_module_fsdp = self._build_model_optimizer(model_path=self.config.model.path,
                                                               fsdp_config=self.config.ref.fsdp_config,
                                                               optim_config=None,
                                                               override_model_config=override_model_config,
                                                               trust_remote_code=self.config.model.get(
                                                                   'trust_remote_code', False))[0]
            # In a combined role, _is_offload_param comes from the actor's fsdp_config (see __init__).
            if self._is_offload_param:
                offload_fsdp_param_and_grad(module=self.ref_module_fsdp, offload_grad=self._is_offload_grad)

            OmegaConf.set_struct(self.config.ref, True)
            self.ref_policy = DataParallelPPOActor(config=self.config.ref, actor_module=self.ref_module_fsdp)

        torch.cuda.synchronize()
        torch.distributed.barrier()
        torch.cuda.empty_cache()
|
|
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def update_actor(self, data: DataProto):
    """Run one policy update on this rank's shard of ``data``.

    If offloading is enabled, FSDP params/grads and optimizer state are first
    reloaded onto the current GPU. The method then calls
    ``self.actor.update_policy`` and steps the LR scheduler. Everything is
    offloaded again before returning.

    Returns:
        A DataProto on CPU. Its ``meta_info['metrics']`` is the dict returned
        by ``update_policy``, plus ``'actor/lr(1e-4)'`` (the current learning
        rate multiplied by 1e4).
    """
    data = data.to('cuda')

    assert self._is_actor
    if self._is_offload_param:
        load_fsdp_param_and_grad(module=self.actor_module_fsdp,
                                 device_id=torch.cuda.current_device(),
                                 load_grad=self._is_offload_grad)
    if self._is_offload_optimizer:
        load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=torch.cuda.current_device())

    data.batch = data.batch.cuda()

    log_gpu_memory_usage('Before update policy', logger=logger)

    metrics = self.actor.update_policy(data=data)

    # The scheduler is stepped once per update_actor call.
    self.actor_lr_scheduler.step()
    lr = self.actor_lr_scheduler.get_last_lr()[0]
    metrics['actor/lr(1e-4)'] = lr * 1e4

    log_gpu_memory_usage('After update policy', logger=logger)

    output = DataProto(meta_info={'metrics': metrics})
    output = output.to('cpu')

    if self._is_offload_param:
        offload_fsdp_param_and_grad(module=self.actor_module_fsdp, offload_grad=self._is_offload_grad)
    if self._is_offload_optimizer:
        offload_fsdp_optimizer(optimizer=self.actor_optimizer)
    torch.cuda.synchronize()
    torch.distributed.barrier()
    torch.cuda.empty_cache()
    return output
|
|
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_entropy(self, data: DataProto):
    """Compute entropy metrics for this rank's shard of ``data`` using the actor.

    If param offload is enabled, FSDP params/grads are reloaded onto the
    current GPU. The method delegates to ``self.actor.compute_entropy`` and
    offloads again before returning.

    Returns:
        A DataProto on CPU whose ``meta_info['metrics']`` is the value
        returned by ``self.actor.compute_entropy``.
    """
    data = data.to('cuda')

    assert self._is_actor
    if self._is_offload_param:
        load_fsdp_param_and_grad(module=self.actor_module_fsdp,
                                 device_id=torch.cuda.current_device(),
                                 load_grad=self._is_offload_grad)

    data.batch = data.batch.cuda()

    log_gpu_memory_usage('Before compute entropy', logger=logger)

    # NOTE(review): the keyword is spelled `bacth_data`. It must match the
    # parameter name of DataParallelPPOActor.compute_entropy — verify there
    # before "fixing" the spelling here.
    metrics = self.actor.compute_entropy(bacth_data=data)

    log_gpu_memory_usage('After compute entropy', logger=logger)

    output = DataProto(meta_info={'metrics': metrics})
    output = output.to('cpu')

    if self._is_offload_param:
        offload_fsdp_param_and_grad(module=self.actor_module_fsdp, offload_grad=self._is_offload_grad)
    # NOTE(review): the optimizer is never loaded in this method. This offload
    # appears to be a no-op kept for symmetry with update_actor.
    if self._is_offload_optimizer:
        offload_fsdp_optimizer(optimizer=self.actor_optimizer)
    torch.cuda.synchronize()
    torch.distributed.barrier()
    torch.cuda.empty_cache()
    return output
|
|
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def generate_sequences(self, prompts: DataProto):
    """Generate rollouts for ``prompts`` and optionally recompute old log-probs.

    Inside the sharding-manager context, prompts pass through
    ``preprocess_data``, then the rollout engine, then ``postprocess_data``.
    If this worker is also the actor and ``meta_info['recompute_log_prob']``
    is not False (it defaults to True), the actor recomputes log-probs for the
    generated output. They are stored as ``output.batch['old_log_probs']``.

    Returns:
        The generated DataProto, moved to CPU.
    """
    prompts = prompts.to('cuda')
    # Read before the meta_info update below; defaults to recomputing.
    recompute_log_prob = prompts.meta_info.get('recompute_log_prob', True)

    assert self._is_rollout
    if self._is_offload_param:
        load_fsdp_param_and_grad(module=self.actor_module_fsdp,
                                 device_id=torch.cuda.current_device(),
                                 load_grad=self._is_offload_grad)

    prompts.batch = prompts.batch.cuda()
    # Expose the tokenizer's special-token ids to the rollout engine.
    meta_info = {'eos_token_id': self.tokenizer.eos_token_id, 'pad_token_id': self.tokenizer.pad_token_id}
    prompts.meta_info.update(meta_info)
    with self.sharding_manager:
        log_gpu_memory_usage('After entering sharding manager', logger=logger)

        prompts = self.sharding_manager.preprocess_data(prompts)
        output = self.rollout.generate_sequences(prompts=prompts)

        log_gpu_memory_usage('After rollout generation', logger=logger)

        output = self.sharding_manager.postprocess_data(output)
        torch.cuda.synchronize()

    if self._is_actor and recompute_log_prob:
        # Log-prob computation settings come from the rollout config.
        output.meta_info['micro_batch_size'] = self.config.rollout.log_prob_micro_batch_size
        output.meta_info['temperature'] = self.config.rollout.temperature
        output.meta_info['use_dynamic_bsz'] = self.config.rollout.log_prob_use_dynamic_bsz
        output.meta_info['max_token_len'] = self.config.rollout.log_prob_max_token_len_per_gpu
        old_log_probs = self.actor.compute_log_prob(data=output)
        output.batch['old_log_probs'] = old_log_probs

    output = output.to('cpu')

    if self._is_offload_param:
        offload_fsdp_param_and_grad(module=self.actor_module_fsdp, offload_grad=self._is_offload_grad)

    torch.cuda.synchronize()
    torch.distributed.barrier()
    torch.cuda.empty_cache()
    log_gpu_memory_usage('After recompute log prob', logger=logger)
    return output
|
|
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_ref_log_prob(self, data: DataProto):
    """Compute reference-policy log-probs for this rank's shard of ``data``.

    Micro-batch size, token budget and dynamic-batching settings come from
    ``config.ref``. The sampling temperature comes from ``config.rollout``.

    Returns:
        A DataProto on CPU with a single tensor ``'ref_log_prob'``.
    """
    assert self._is_ref

    data = data.to('cuda')

    if self._is_offload_param:
        load_fsdp_param_and_grad(module=self.ref_module_fsdp,
                                 device_id=torch.cuda.current_device(),
                                 load_grad=self._is_offload_grad)

    micro_batch_size = self.config.ref.log_prob_micro_batch_size
    data.meta_info['micro_batch_size'] = micro_batch_size
    # The rollout temperature is used so ref log-probs match the sampling setup.
    data.meta_info['temperature'] = self.config.rollout.temperature
    data.meta_info['max_token_len'] = self.config.ref.log_prob_max_token_len_per_gpu
    data.meta_info['use_dynamic_bsz'] = self.config.ref.log_prob_use_dynamic_bsz
    output = self.ref_policy.compute_log_prob(data=data)
    output = DataProto.from_dict(tensors={'ref_log_prob': output})

    output = output.to('cpu')

    if self._is_offload_param:
        offload_fsdp_param_and_grad(module=self.ref_module_fsdp, offload_grad=self._is_offload_grad)
    torch.cuda.synchronize()
    torch.distributed.barrier()
    torch.cuda.empty_cache()
    return output
|
|
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def save_checkpoint(self, local_path, hdfs_path=None):
    """Save the actor weights and tokenizer in Hugging Face format.

    The method is dispatched ONE_TO_ALL and ends with a barrier, so every rank
    runs it. All ranks take part in gathering a full FSDP state dict, which is
    offloaded to CPU and kept on rank 0 only. Rank 0 then writes it with
    ``save_pretrained`` to ``local_path`` and, if ``hdfs_path`` is given,
    copies the directory there.

    Args:
        local_path: Local directory to write (created if missing).
        hdfs_path: Optional remote destination for a copy of ``local_path``.
    """
    assert self._is_actor
    import torch
    if self._is_offload_param:
        load_fsdp_param_and_grad(module=self.actor_module_fsdp,
                                 device_id=torch.cuda.current_device(),
                                 load_grad=self._is_offload_grad)

    import torch.distributed
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType, FullStateDictConfig
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(self.actor.actor_module, StateDictType.FULL_STATE_DICT, cfg):
        state_dict = self.actor.actor_module.state_dict()
    if self.rank == 0:
        print(f'Saving actor checkpoint to {local_path}')
        os.makedirs(local_path, exist_ok=True)
        # Save through the unwrapped module so the HF config/format is used.
        self.actor_module.save_pretrained(local_path, state_dict=state_dict)
        self.tokenizer.save_pretrained(local_path)
        if hdfs_path is not None:
            print(f'Uploading actor checkpoint to {hdfs_path}')
            hdfs_io.makedirs(hdfs_path, exist_ok=True)
            hdfs_io.copy(src=local_path, dst=hdfs_path)

    # Keep other ranks from racing ahead while rank 0 writes.
    torch.distributed.barrier()
    if self._is_offload_param:
        offload_fsdp_param_and_grad(module=self.actor_module_fsdp, offload_grad=self._is_offload_grad)
|
|
|
|
|
|
|
|
|
|
|
|
class CriticWorker(Worker):
    """FSDP worker that owns the PPO critic (value model).

    The critic is a causal-LM backbone loaded from ``config.model.path``. Its
    ``lm_head`` is replaced by ``Linear(hidden_size, 1)`` followed by a squeeze.
    Value computation, critic updates and checkpointing are exposed as
    dispatchable methods. When the corresponding offload flags are set, FSDP
    params/grads and optimizer state are moved onto the GPU before each call
    and back afterwards.
    """

    def __init__(self, config):
        super().__init__()
        import torch.distributed
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend="nccl")
        self.config = config
        self._is_offload_param = self.config.model.fsdp_config.param_offload
        self._is_offload_grad = self.config.model.fsdp_config.grad_offload
        self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload

        # The configured batch sizes are divided by the world size in place, so
        # the rest of the worker sees per-rank sizes.
        self.config.ppo_mini_batch_size //= torch.distributed.get_world_size()
        self.config.ppo_micro_batch_size //= torch.distributed.get_world_size()

    def _build_critic_model_optimizer(self, config):
        """Load the critic, wrap it in FSDP, and build its optimizer and LR scheduler.

        Also sets ``self.tokenizer`` from ``config.model.tokenizer_path``.

        Returns:
            ``(critic_module, critic_optimizer, critic_lr_scheduler)``. The
            module is FSDP-wrapped (FULL_SHARD), the optimizer is AdamW, and
            the scheduler comes from ``get_constant_schedule_with_warmup``.
        """
        from verl.utils.model import LambdaLayer, print_model_size, squeeze
        from verl.utils.torch_dtypes import PrecisionType
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision
        from torch import optim

        local_path = copy_local_path_from_hdfs(config.model.path)

        # One trust_remote_code setting for tokenizer, config and weights.
        # Previously the config and weights were forced to False even when the
        # tokenizer was loaded with trust_remote_code=True.
        trust_remote_code = config.model.get('trust_remote_code', False)

        tokenizer_path = copy_local_path_from_hdfs(config.model.tokenizer_path)
        self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=trust_remote_code)

        from omegaconf import OmegaConf
        override_config = OmegaConf.to_container(self.config.model.get('override_config', OmegaConf.create()))
        override_config_kwargs = {
            'bos_token_id': self.tokenizer.bos_token_id,
            'eos_token_id': self.tokenizer.eos_token_id,
            'pad_token_id': self.tokenizer.pad_token_id,
        }
        override_config_kwargs.update(override_config)
        if self.rank == 0:
            print(f'Critic overriding config {override_config_kwargs}')
        # NOTE(review): override_config_kwargs is only printed here. It is never
        # applied to critic_model_config below — confirm whether it should be.

        torch_dtype = self.config.model.fsdp_config.get('model_dtype', 'fp32')
        torch_dtype = PrecisionType.to_dtype(torch_dtype)

        from transformers import AutoConfig, AutoModelForCausalLM
        from torch import nn

        critic_model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)

        init_context = get_init_weight_context_manager()
        with init_context(), warnings.catch_warnings():
            warnings.simplefilter("ignore")
            critic_module = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=local_path,
                                                                 torch_dtype=torch_dtype,
                                                                 config=critic_model_config,
                                                                 attn_implementation='flash_attention_2',
                                                                 trust_remote_code=trust_remote_code)
            # Replace the vocabulary head with a scalar value head.
            critic_module.lm_head = nn.Sequential(nn.Linear(critic_model_config.hidden_size, 1, dtype=torch_dtype),
                                                  LambdaLayer(fn=squeeze))

            # Cast the whole module, including the new head, to the requested dtype.
            critic_module.to(torch_dtype)

            if config.model.get('enable_gradient_checkpointing', False):
                critic_module.gradient_checkpointing_enable()
        if self.rank == 0:
            print_model_size(critic_module)

        fsdp_config = self.config.model.fsdp_config
        mixed_precision_config = fsdp_config.get('mixed_precision', None)
        if mixed_precision_config is not None:
            param_dtype = PrecisionType.to_dtype(mixed_precision_config.get('param_dtype', 'bf16'))
            reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get('reduce_dtype', 'fp32'))
            buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get('buffer_dtype', 'fp32'))
        else:
            param_dtype = torch.bfloat16
            reduce_dtype = torch.float32
            buffer_dtype = torch.float32

        mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)

        auto_wrap_policy = get_fsdp_wrap_policy(module=critic_module, config=self.config.model.fsdp_config.wrap_policy)

        log_gpu_memory_usage('Before critic FSDP', logger=None)

        critic_module = FSDP(critic_module,
                             param_init_fn=init_fn,
                             use_orig_params=False,
                             auto_wrap_policy=auto_wrap_policy,
                             device_id=torch.cuda.current_device(),
                             sharding_strategy=ShardingStrategy.FULL_SHARD,
                             mixed_precision=mixed_precision,
                             sync_module_states=True)

        log_gpu_memory_usage('After critic FSDP', logger=None)

        critic_optimizer = optim.AdamW(critic_module.parameters(),
                                       lr=config.optim.lr,
                                       betas=config.optim.get('betas', (0.9, 0.999)),
                                       weight_decay=config.optim.get('weight_decay', 1e-2))

        total_steps = config.optim.get('total_training_steps', 0)
        num_warmup_steps_ratio = config.optim.get('lr_warmup_steps_ratio', 0.)
        num_warmup_steps = int(num_warmup_steps_ratio * total_steps)

        print(f'Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}')

        from verl.utils.torch_functional import get_constant_schedule_with_warmup
        critic_lr_scheduler = get_constant_schedule_with_warmup(optimizer=critic_optimizer,
                                                                num_warmup_steps=num_warmup_steps)

        return critic_module, critic_optimizer, critic_lr_scheduler

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def init_model(self):
        """Build the critic and wrap it in a ``DataParallelPPOCritic``; offload state if configured."""
        import_external_libs(self.config.model.get('external_lib', None))

        from verl.workers.critic import DataParallelPPOCritic
        self.critic_module, self.critic_optimizer, self.critic_lr_scheduler = self._build_critic_model_optimizer(
            self.config)

        if self._is_offload_param:
            offload_fsdp_param_and_grad(module=self.critic_module, offload_grad=self._is_offload_grad)
        if self._is_offload_optimizer:
            offload_fsdp_optimizer(optimizer=self.critic_optimizer)

        self.critic = DataParallelPPOCritic(config=self.config,
                                            critic_module=self.critic_module,
                                            critic_optimizer=self.critic_optimizer)
        torch.cuda.empty_cache()

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def compute_values(self, data: DataProto):
        """Return a CPU DataProto holding ``'values'`` predicted for this rank's shard of ``data``."""
        data = data.to('cuda')

        if self._is_offload_param:
            load_fsdp_param_and_grad(module=self.critic_module,
                                     device_id=torch.cuda.current_device(),
                                     load_grad=self._is_offload_grad)
        micro_batch_size = self.config.ppo_micro_batch_size
        data.meta_info['micro_batch_size'] = micro_batch_size
        data.meta_info['max_token_len'] = self.config.forward_max_token_len_per_gpu
        data.meta_info['use_dynamic_bsz'] = self.config.use_dynamic_bsz
        values = self.critic.compute_values(data=data)
        output = DataProto.from_dict(tensors={'values': values})
        output = output.to('cpu')
        if self._is_offload_param:
            offload_fsdp_param_and_grad(module=self.critic_module, offload_grad=self._is_offload_grad)
        torch.cuda.empty_cache()
        return output

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def update_critic(self, data: DataProto):
        """Run one critic update and step its LR scheduler.

        Returns:
            A CPU DataProto whose ``meta_info['metrics']`` holds the update
            metrics plus ``'critic/lr(1e-4)'`` (the learning rate times 1e4).
        """
        data = data.to('cuda')
        if self._is_offload_param:
            load_fsdp_param_and_grad(module=self.critic_module,
                                     device_id=torch.cuda.current_device(),
                                     load_grad=self._is_offload_grad)
        if self._is_offload_optimizer:
            load_fsdp_optimizer(optimizer=self.critic_optimizer, device_id=torch.cuda.current_device())
        metrics = self.critic.update_critic(data=data)

        self.critic_lr_scheduler.step()
        lr = self.critic_lr_scheduler.get_last_lr()[0]
        metrics['critic/lr(1e-4)'] = lr * 1e4

        output = DataProto(batch=None, meta_info={'metrics': metrics})
        if self._is_offload_param:
            offload_fsdp_param_and_grad(module=self.critic_module, offload_grad=self._is_offload_grad)
        if self._is_offload_optimizer:
            offload_fsdp_optimizer(optimizer=self.critic_optimizer)
        torch.cuda.empty_cache()
        output = output.to('cpu')
        return output

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def save_checkpoint(self, local_path, hdfs_path=None):
        """Gather a full state dict on rank 0 and save critic + tokenizer in HF format.

        Runs on every rank (ONE_TO_ALL dispatch, ends with a barrier). When
        ``hdfs_path`` is given, rank 0 also copies ``local_path`` there.
        """
        import torch
        if self._is_offload_param:
            load_fsdp_param_and_grad(module=self.critic_module,
                                     device_id=torch.cuda.current_device(),
                                     load_grad=self._is_offload_grad)

        import torch.distributed
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType, FullStateDictConfig
        cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        with FSDP.state_dict_type(self.critic_module, StateDictType.FULL_STATE_DICT, cfg):
            state_dict = self.critic_module.state_dict()
        if self.rank == 0:
            print(f'Saving critic checkpoint to {local_path}')
            os.makedirs(local_path, exist_ok=True)
            self.critic_module._fsdp_wrapped_module.save_pretrained(local_path, state_dict=state_dict)
            self.tokenizer.save_pretrained(local_path)
            if hdfs_path is not None:
                print(f'Uploading critic checkpoint to {hdfs_path}')
                hdfs_io.makedirs(hdfs_path, exist_ok=True)
                hdfs_io.copy(src=local_path, dst=hdfs_path)

        torch.distributed.barrier()
        if self._is_offload_param:
            offload_fsdp_param_and_grad(module=self.critic_module, offload_grad=self._is_offload_grad)
|
|
|
|
class RewardModelWorker(Worker):
    """FSDP worker that scores rollouts with a reward model.

    Note that we only implement the reward model that is subclass of
    AutoModelForSequenceClassification. If ``config.model.input_tokenizer`` is
    set, each response is decoded with that tokenizer and re-encoded with the
    reward model's own chat template before scoring. Otherwise the rollout
    tokens are scored as-is.
    """

    def __init__(self, config):
        super().__init__()
        import torch.distributed
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend="nccl")
        self.config = config

        # Convert the configured micro batch size to a per-rank size in place.
        self.config.micro_batch_size //= torch.distributed.get_world_size()

    def _build_model(self, config):
        """Load the bf16 sequence-classification reward model and wrap it in FSDP.

        Side effects: sets ``self._do_switch_chat_template``. When an input
        tokenizer is configured, also sets ``self.input_tokenizer`` and
        ``self.tokenizer``.
        """
        from transformers import AutoModelForSequenceClassification, AutoConfig
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, CPUOffload

        local_path = copy_local_path_from_hdfs(config.model.path)

        trust_remote_code = config.model.get('trust_remote_code', False)
        if self.config.model.input_tokenizer is None:
            self._do_switch_chat_template = False
        else:
            self._do_switch_chat_template = True
            input_tokenizer_local_path = copy_local_path_from_hdfs(config.model.input_tokenizer)
            self.input_tokenizer = hf_tokenizer(input_tokenizer_local_path, trust_remote_code=trust_remote_code)
            self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)

        model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)
        # Meta-tensor initialisation is disabled for models with tied embeddings.
        init_context = get_init_weight_context_manager(use_meta_tensor=not model_config.tie_word_embeddings)

        with init_context(), warnings.catch_warnings():
            warnings.simplefilter("ignore")
            reward_module = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path=local_path,
                                                                               torch_dtype=torch.bfloat16,
                                                                               attn_implementation='flash_attention_2',
                                                                               trust_remote_code=trust_remote_code)
            reward_module.to(torch.bfloat16)
        # NOTE(review): CriticWorker passes fsdp_config.wrap_policy here, while this
        # passes the whole fsdp_config — confirm which get_fsdp_wrap_policy expects.
        auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.model.fsdp_config)

        reward_module = FSDP(
            reward_module,
            param_init_fn=init_fn,
            use_orig_params=False,
            auto_wrap_policy=auto_wrap_policy,
            device_id=torch.cuda.current_device(),
            sharding_strategy=ShardingStrategy.FULL_SHARD,
            sync_module_states=True,
            cpu_offload=CPUOffload(offload_params=self.config.model.fsdp_config.param_offload))

        return reward_module

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def init_model(self):
        """Import any external libs and build the FSDP reward model."""
        import_external_libs(self.config.model.get('external_lib', None))
        self.reward_module = self._build_model(config=self.config)
        torch.cuda.empty_cache()

    def _forward_micro_batch(self, micro_batch):
        """Score one micro batch under no_grad + bf16 autocast.

        Returns the model logits with the last dim squeezed. NOTE(review):
        presumably (batch,) for a single-label RM — confirm num_labels == 1.
        """
        with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            output = self.reward_module(input_ids=micro_batch['input_ids'],
                                        attention_mask=micro_batch['attention_mask'],
                                        position_ids=micro_batch['position_ids'])
            rm_score = output.logits
            rm_score = rm_score.squeeze(-1)
            return rm_score

    def _expand_to_token_level(self, data: DataProto, scores: torch.Tensor):
        """Place each sequence score at one token position and return the response slice.

        The index is ``argmax(position_ids * attention_mask)``. That is
        presumably the last valid token, assuming position ids increase along
        attended tokens. The result has the same trailing length as
        ``data.batch['responses']``.
        """
        batch_size = data.batch.batch_size[0]
        attention_mask = data.batch['attention_mask']
        position_ids = data.batch['position_ids']
        response_length = data.batch['responses'].shape[-1]
        eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1)
        token_level_scores = torch.zeros_like(attention_mask, dtype=scores.dtype)
        token_level_scores[torch.arange(batch_size), eos_mask_idx] = scores

        # Keep only the response segment.
        token_level_scores = token_level_scores[:, -response_length:]

        return token_level_scores

    def _switch_chat_template(self, data: DataProto):
        """Re-tokenise prompt + response with the reward model's chat template.

        Each valid response is decoded with ``self.input_tokenizer`` (EOS text
        removed) and appended as an assistant turn to ``raw_prompt``. The chat is
        rendered with ``self.tokenizer.apply_chat_template`` and tokenised with
        right padding.

        Returns:
            DataProto with ``input_ids``, ``attention_mask`` and ``position_ids``.
        """
        src_max_length = data.batch['attention_mask'].shape[-1]

        src_tokenizer = self.input_tokenizer
        target_tokenizer = self.tokenizer

        # Loop-invariant tokenisation settings, read once.
        max_length = self.config.get('max_length', src_max_length)
        if max_length is None:
            max_length = src_max_length
        truncation = self.config.get('truncation', 'right')

        rm_input_ids = []
        rm_attention_mask = []

        for i in range(data.batch.batch_size[0]):
            chat: list = data.non_tensor_batch['raw_prompt'][i].tolist()

            # The response occupies the trailing `response_length` mask positions.
            response_ids = data.batch['responses'][i]
            response_length = response_ids.shape[-1]
            valid_response_length = data.batch['attention_mask'][i][-response_length:].sum()
            valid_response_ids = response_ids[:valid_response_length]

            response = src_tokenizer.decode(valid_response_ids)
            response = response.replace(src_tokenizer.eos_token, '')

            chat.append({'role': 'assistant', 'content': response})

            prompt_with_chat_template = target_tokenizer.apply_chat_template(chat,
                                                                             add_generation_prompt=False,
                                                                             tokenize=False)
            if self.rank == 0 and i == 0:
                print(f'Switch template. chat: {prompt_with_chat_template}')

            input_ids, attention_mask = verl_F.tokenize_and_postprocess_data(
                prompt=prompt_with_chat_template,
                tokenizer=target_tokenizer,
                max_length=max_length,
                pad_token_id=target_tokenizer.pad_token_id,
                left_pad=False,
                truncation=truncation)

            rm_input_ids.append(input_ids)
            rm_attention_mask.append(attention_mask)

        rm_input_ids = torch.cat(rm_input_ids, dim=0)
        rm_attention_mask = torch.cat(rm_attention_mask, dim=0)

        rm_position_ids = compute_position_id_with_mask(rm_attention_mask)

        rm_inputs = {'input_ids': rm_input_ids, 'attention_mask': rm_attention_mask, 'position_ids': rm_position_ids}

        return DataProto.from_dict(rm_inputs)

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def compute_rm_score(self, data: DataProto):
        """Score this rank's shard of ``data`` with the reward model.

        Returns:
            A CPU DataProto with ``'rm_scores'``. Each sequence score sits at a
            single token position within the response segment.
        """
        data = data.to('cuda')
        if self._do_switch_chat_template:
            rm_data = self._switch_chat_template(data)
        else:
            # No input tokenizer configured: score the rollout tokens directly.
            # Previously `rm_data` was left unbound on this path.
            rm_data = data

        rm_data.batch = rm_data.batch.cuda()
        micro_batches = rm_data.batch.split(self.config.micro_batch_size)
        output = []
        for micro_batch in micro_batches:
            rm_score = self._forward_micro_batch(micro_batch)
            output.append(rm_score)
        scores = torch.cat(output, dim=0)
        token_level_scores = self._expand_to_token_level(data, scores)

        output = DataProto.from_dict(tensors={'rm_scores': token_level_scores})
        output = output.to('cpu')
        torch.cuda.empty_cache()
        return output
|
|
| class PRIMERewardModelWorker(Worker): |
| """ |
| PRIME reward model. |
| Can update itself whenever compute_rm_score is called. |
| """ |
def __init__(self, config):
    """Set up the process group, per-rank batch sizes, the FSDP device mesh and offload flags."""
    super().__init__()
    import torch.distributed
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl")
    self.config = config

    n_ranks = torch.distributed.get_world_size()
    # The configured batch sizes are turned into per-rank sizes in place.
    self.config.mini_batch_size //= n_ranks
    self.config.micro_batch_size //= n_ranks

    from torch.distributed.device_mesh import init_device_mesh

    # One-dimensional mesh over all ranks; handed to FSDP when models are wrapped.
    self.device_mesh = init_device_mesh('cuda', mesh_shape=(n_ranks,), mesh_dim_names=['fsdp'])

    # Every offload flag is optional and defaults to False.
    offload_cfg = self.config.prime_model.fsdp_config
    self._is_offload_param = offload_cfg.get('param_offload', False)
    self._is_offload_grad = offload_cfg.get('grad_offload', False)
    self._is_offload_optimizer = offload_cfg.get('optimizer_offload', False)
|
|
def _build_model_optimizer(self, config, enable_gradient_checkpointing=False):
    """Build the PRIME reward model, an optional frozen reference model, and an optional optimizer.

    Models:

    * The reward model is loaded with ``AutoLigerKernelForCausalLM`` from
      ``config.prime_model.path``. Weights are fp32 and FSDP runs with bf16
      param mixed precision.
    * When ``prime_model.ref_type`` is ``'freeze'`` (the default), a bf16
      reference model is loaded from ``prime_model.ref_path`` with
      ``requires_grad=False``.
    * Both models are wrapped with FULL_SHARD FSDP on ``self.device_mesh``.

    Optimizer: when ``prime_model.update`` is ``'before'`` or ``'after'``, an
    AdamW optimizer and a warmup LR scheduler are stored as
    ``self.reward_optimizer`` / ``self.reward_lr_scheduler``. Otherwise both
    are set to ``None``; previously they were left undefined.

    Side effects: sets ``self._do_switch_chat_template``,
    ``self.update_dpo_type`` and ``self.tokenizer``. When an input tokenizer
    is configured, also sets ``self.input_tokenizer``.

    Returns:
        ``(reward_module, reference_module)``. ``reference_module`` is ``None``
        unless ``ref_type == 'freeze'``.
    """
    from transformers import AutoConfig
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision

    local_path = copy_local_path_from_hdfs(config.prime_model.path)
    trust_remote_code = config.prime_model.get('trust_remote_code', False)

    if self.config.prime_model.input_tokenizer is None:
        self._do_switch_chat_template = False
    else:
        self._do_switch_chat_template = True
        input_tokenizer_local_path = copy_local_path_from_hdfs(config.prime_model.input_tokenizer)
        self.input_tokenizer = hf_tokenizer(input_tokenizer_local_path, trust_remote_code=trust_remote_code)
    # Loaded unconditionally so self.tokenizer exists even without a template
    # switch. NOTE(review): save_checkpoint presumably saves it — confirm.
    self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)

    model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)

    if config.prime_model.use_remove_padding:
        from verl.models.registry import check_model_support_rmpad
        check_model_support_rmpad(model_config.model_type)
    # Meta-tensor initialisation is disabled for models with tied embeddings.
    init_context = get_init_weight_context_manager(use_meta_tensor=not model_config.tie_word_embeddings)

    with init_context(), warnings.catch_warnings():
        warnings.simplefilter("ignore")
        from liger_kernel.transformers import AutoLigerKernelForCausalLM
        reward_module = AutoLigerKernelForCausalLM.from_pretrained(pretrained_model_name_or_path=local_path,
                                                                   torch_dtype=torch.float32,
                                                                   attn_implementation='flash_attention_2',
                                                                   trust_remote_code=trust_remote_code)
        reward_module.to(torch.float32)
    # Either flag enables checkpointing. It used to be enabled twice when both
    # were set, because init_model passes the same config value.
    if enable_gradient_checkpointing or config.prime_model.get('enable_gradient_checkpointing', False):
        reward_module.gradient_checkpointing_enable()
    mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32,
                                     buffer_dtype=torch.float32)

    if config.prime_model.get("ref_type", 'freeze') == 'freeze':
        reference_module = AutoLigerKernelForCausalLM.from_pretrained(
            pretrained_model_name_or_path=copy_local_path_from_hdfs(config.prime_model.ref_path),
            torch_dtype=torch.bfloat16,
            attn_implementation='flash_attention_2',
            trust_remote_code=trust_remote_code)
        reference_module.to(torch.bfloat16)
        for param in reference_module.parameters():
            param.requires_grad = False
    else:
        reference_module = None

    auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.prime_model.fsdp_config)

    reward_module = FSDP(
        reward_module,
        param_init_fn=init_fn,
        use_orig_params=False,
        auto_wrap_policy=auto_wrap_policy,
        device_id=torch.cuda.current_device(),
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        mixed_precision=mixed_precision,
        device_mesh=self.device_mesh,
        sync_module_states=True)

    if reference_module is not None:
        # Build the wrap policy only when there is a module to wrap.
        # Previously it was computed with module=None as well.
        ref_wrap_policy = get_fsdp_wrap_policy(module=reference_module, config=self.config.prime_model.fsdp_config)
        reference_module = FSDP(
            reference_module,
            param_init_fn=init_fn,
            use_orig_params=False,
            auto_wrap_policy=ref_wrap_policy,
            device_id=torch.cuda.current_device(),
            sharding_strategy=ShardingStrategy.FULL_SHARD,
            device_mesh=self.device_mesh,
            sync_module_states=True)

    self.update_dpo_type = self.config.prime_model.get('update', 'none')
    # Explicit defaults so callers see None instead of hitting AttributeError
    # when no online update is configured.
    self.reward_optimizer = None
    self.reward_lr_scheduler = None
    if self.update_dpo_type in ['before', 'after']:
        from torch import optim
        self.reward_optimizer = optim.AdamW(reward_module.parameters(),
                                            lr=config.prime_model.optim.lr,
                                            betas=config.prime_model.optim.get('betas', (0.9, 0.999)),
                                            weight_decay=config.prime_model.optim.get('weight_decay', 1e-2))

        total_steps = config.prime_model.optim.get('total_training_steps', 0)
        num_warmup_steps_ratio = config.prime_model.optim.get('lr_warmup_steps_ratio', 0.)
        num_warmup_steps = int(num_warmup_steps_ratio * total_steps)

        print(f'Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}')

        from verl.utils.torch_functional import get_constant_schedule_with_warmup
        self.reward_lr_scheduler = get_constant_schedule_with_warmup(optimizer=self.reward_optimizer,
                                                                     num_warmup_steps=num_warmup_steps)

        # Only an optimizer that actually exists can be offloaded.
        if self._is_offload_optimizer:
            offload_fsdp_optimizer(optimizer=self.reward_optimizer)

    if self._is_offload_param:
        offload_fsdp_param_and_grad(module=reward_module, offload_grad=self._is_offload_grad)
        if reference_module is not None:
            offload_fsdp_param_and_grad(module=reference_module, offload_grad=self._is_offload_grad)

    return reward_module, reference_module
|
|
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
    """Build the PRIME reward/reference models and wrap them in ``DataParallelPRIME``.

    ``self.reward_optimizer`` is only created by ``_build_model_optimizer``
    when ``prime_model.update`` is ``'before'`` or ``'after'``. ``getattr`` is
    used so the default ``'none'`` setting passes ``None`` instead of raising
    AttributeError. NOTE(review): confirm DataParallelPRIME accepts a None
    optimizer when no update is configured.
    """
    from verl.workers.actor import DataParallelPRIME

    import_external_libs(self.config.prime_model.get('external_lib', None))
    self.reward_module, self.reference_module = self._build_model_optimizer(
        config=self.config,
        enable_gradient_checkpointing=self.config.prime_model.get('enable_gradient_checkpointing', False))
    self.prm = DataParallelPRIME(config=self.config,
                                 reward_module=self.reward_module,
                                 reference_module=self.reference_module,
                                 reward_optimizer=getattr(self, 'reward_optimizer', None),
                                 prime_loss_fn=self.config.prime_model.get('loss_type', 'ce'))
    torch.cuda.empty_cache()
|
|
def _switch_chat_template(self, data: DataProto):
    """Re-tokenise prompt + response with the reward model's chat template.

    For each sample:

    1. The valid part of the response is decoded with ``self.input_tokenizer``
       and the EOS text is stripped.
    2. The response is appended as an assistant turn to ``raw_prompt``.
    3. The chat is rendered with ``self.tokenizer.apply_chat_template`` and
       tokenised with right padding.

    Returns:
        DataProto with ``input_ids``, ``attention_mask`` and ``position_ids``.
    """
    src_max_length = data.batch['attention_mask'].shape[-1]

    src_tokenizer = self.input_tokenizer
    target_tokenizer = self.tokenizer

    # Loop-invariant tokenisation settings, read once instead of per sample.
    max_length = self.config.get('max_length', src_max_length)
    if max_length is None:
        max_length = src_max_length
    truncation = self.config.get('truncation', 'right')

    rm_input_ids = []
    rm_attention_mask = []

    for i in range(data.batch.batch_size[0]):
        chat: list = data.non_tensor_batch['raw_prompt'][i].tolist()

        # The response occupies the trailing `response_length` mask positions.
        response_ids = data.batch['responses'][i]
        response_length = response_ids.shape[-1]
        valid_response_length = data.batch['attention_mask'][i][-response_length:].sum()
        valid_response_ids = response_ids[:valid_response_length]

        response = src_tokenizer.decode(valid_response_ids)
        response = response.replace(src_tokenizer.eos_token, '')

        chat.append({'role': 'assistant', 'content': response})

        prompt_with_chat_template = target_tokenizer.apply_chat_template(chat,
                                                                         add_generation_prompt=False,
                                                                         tokenize=False)
        if self.rank == 0 and i == 0:
            print(f'Switch template. chat: {prompt_with_chat_template}')

        input_ids, attention_mask = verl_F.tokenize_and_postprocess_data(
            prompt=prompt_with_chat_template,
            tokenizer=target_tokenizer,
            max_length=max_length,
            pad_token_id=target_tokenizer.pad_token_id,
            left_pad=False,
            truncation=truncation)

        rm_input_ids.append(input_ids)
        rm_attention_mask.append(attention_mask)

    rm_input_ids = torch.cat(rm_input_ids, dim=0)
    rm_attention_mask = torch.cat(rm_attention_mask, dim=0)

    rm_position_ids = compute_position_id_with_mask(rm_attention_mask)

    rm_inputs = {'input_ids': rm_input_ids, 'attention_mask': rm_attention_mask, 'position_ids': rm_position_ids}

    return DataProto.from_dict(rm_inputs)
|
|
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_rm_score(self, data: DataProto):
    """Score a rollout batch with the PRIME reward model, updating it if configured.

    ``self.prm.update_policy`` returns both the token-level scores and the
    metrics. When ``update_dpo_type != 'none'``:

    * the optimizer state is loaded/offloaded around the call, if optimizer
      offload is enabled;
    * the LR scheduler is stepped once afterwards.

    Reward/reference params are loaded/offloaded around the call whenever
    param offload is enabled.

    The previously computed but unused locals ``n_samples`` and ``beta`` have
    been removed.

    Returns:
        A CPU DataProto with ``'rm_scores'`` and ``meta_info['metrics']``.
    """
    rm_data = self._switch_chat_template(data) if self._do_switch_chat_template else data
    updates_reward_model = self.update_dpo_type != 'none'

    if updates_reward_model and self._is_offload_optimizer:
        load_fsdp_optimizer(optimizer=self.reward_optimizer, device_id=torch.cuda.current_device())
    if self._is_offload_param:
        load_fsdp_param_and_grad(module=self.reward_module, device_id=torch.cuda.current_device(),
                                 load_grad=self._is_offload_grad)
        if self.reference_module is not None:
            load_fsdp_param_and_grad(module=self.reference_module, device_id=torch.cuda.current_device(),
                                     load_grad=self._is_offload_grad)

    token_level_scores, metrics = self.prm.update_policy(rm_data)

    output = DataProto.from_dict(tensors={'rm_scores': token_level_scores}, meta_info={'metrics': metrics})

    if updates_reward_model:
        if self._is_offload_optimizer:
            offload_fsdp_optimizer(optimizer=self.reward_optimizer)
        self.reward_lr_scheduler.step()
    if self._is_offload_param:
        offload_fsdp_param_and_grad(module=self.reward_module, offload_grad=self._is_offload_grad)
        if self.reference_module is not None:
            offload_fsdp_param_and_grad(module=self.reference_module, offload_grad=self._is_offload_grad)

    output = output.to('cpu')
    torch.cuda.empty_cache()
    return output
|
|
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def save_checkpoint(self, local_path, hdfs_path=None):
    """Persist the reward model's full (unsharded) weights from rank 0.

    Every rank must reach this method: the FSDP full-state-dict gather and the
    trailing ``torch.distributed.barrier()`` are both collective operations.
    Only rank 0 writes the checkpoint. It writes to ``local_path`` and, when
    ``hdfs_path`` is given, copies that directory to HDFS.

    Args:
        local_path: Local directory to write into; created if it does not exist.
        hdfs_path: Optional HDFS destination that receives a copy of ``local_path``.
    """
    from torch.distributed.fsdp import FullStateDictConfig, FullyShardedDataParallel as FSDP, StateDictType

    module = self.reward_module

    if self._is_offload_param:
        # Weights (and optionally grads) are kept off the GPU between calls; restore them first.
        load_fsdp_param_and_grad(module=module,
                                 device_id=torch.cuda.current_device(),
                                 load_grad=self._is_offload_grad)

    # Gather the unsharded weights to CPU. With rank0_only=True, only rank 0 receives them;
    # the other ranks get an empty dict.
    full_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(module, StateDictType.FULL_STATE_DICT, full_cfg):
        weights = module.state_dict()

    if self.rank == 0:
        print(f'Saving reward checkpoint to {local_path}')
        os.makedirs(local_path, exist_ok=True)
        # Save through the module wrapped by FSDP (presumably a HuggingFace PreTrainedModel),
        # passing the gathered weights explicitly.
        module._fsdp_wrapped_module.save_pretrained(local_path, state_dict=weights)
        if hdfs_path is not None:
            print(f'Uploading reward checkpoint to {hdfs_path}')
            hdfs_io.makedirs(hdfs_path, exist_ok=True)
            hdfs_io.copy(src=local_path, dst=hdfs_path)

    # No rank continues (or offloads) until rank 0 has finished writing.
    torch.distributed.barrier()
    if self._is_offload_param:
        offload_fsdp_param_and_grad(module=module, offload_grad=self._is_offload_grad)
|
|