| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| The main entry point to run the PPO algorithm |
| """ |
|
|
| import datetime |
| import json |
| import logging |
| import os |
| import warnings |
| from dataclasses import asdict |
|
|
| import psutil |
| import torch |
| import torch.distributed |
| import torch.distributed as dist |
| from codetiming import Timer |
| from omegaconf import DictConfig, OmegaConf, open_dict |
| from omegaconf.errors import ConfigAttributeError |
| from peft import LoraConfig, TaskType, get_peft_model |
| from safetensors.torch import save_file |
| from torch.distributed.device_mesh import init_device_mesh |
| from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
| from torch.distributed.fsdp.api import FullStateDictConfig, ShardedStateDictConfig, StateDictType |
|
|
| try: |
| |
| from torch.distributed.tensor import DTensor |
| except ImportError: |
| from torch.distributed._tensor import DTensor |
|
|
| from verl import DataProto |
| from verl.models.transformers.monkey_patch import apply_monkey_patch |
| from verl.single_controller.base import Worker |
| from verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register |
| from verl.utils import hf_processor, hf_tokenizer |
| from verl.utils.activation_offload import enable_activation_offloading |
| from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager |
| from verl.utils.config import omega_conf_to_dataclass |
| from verl.utils.device import ( |
| get_device_id, |
| get_device_name, |
| get_nccl_backend, |
| get_torch_device, |
| set_expandable_segments, |
| ) |
| from verl.utils.flops_counter import FlopsCounter |
| from verl.utils.fs import copy_to_local |
| from verl.utils.fsdp_utils import ( |
| CPUOffloadPolicy, |
| MixedPrecisionPolicy, |
| apply_fsdp2, |
| collect_lora_params, |
| fsdp2_load_full_state_dict, |
| fsdp_version, |
| get_fsdp_wrap_policy, |
| get_init_weight_context_manager, |
| get_shard_placement_fn, |
| init_fn, |
| layered_summon_lora_params, |
| load_fsdp_model_to_gpu, |
| load_fsdp_optimizer, |
| offload_fsdp_model_to_cpu, |
| offload_fsdp_optimizer, |
| replace_lora_wrapper, |
| ) |
| from verl.utils.import_utils import import_external_libs |
| from verl.utils.memory_utils import aggressive_empty_cache |
| from verl.utils.model import convert_weight_keys |
| from verl.utils.profiler import DistProfiler, DistProfilerExtension, ProfilerConfig, log_gpu_memory_usage, simple_timer |
| from verl.utils.profiler.performance import reduce_timing, topk_reduce_ratio_min_max |
| from verl.utils.py_functional import convert_to_regular_types |
|
|
| |
| from verl.utils.qat import apply_qat, enable_qat_fuse |
| from verl.utils.ray_utils import get_event_loop |
| from verl.utils.transformers_compat import get_auto_model_for_vision2seq |
| from verl.workers.config import FSDPCriticConfig, FSDPEngineConfig, HFModelConfig, RolloutConfig |
| from verl.workers.config.optimizer import build_optimizer |
| from verl.workers.rollout import get_rollout_class |
| from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager |
|
|
# Module-level logger. Its level is read from the VERL_LOGGING_LEVEL environment
# variable and falls back to "WARN" when the variable is unset.
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))

# Device name reported by verl.utils.device.get_device_name(), resolved once at
# import time and reused when building device meshes in this module.
device_name = get_device_name()
|
|
|
|
def create_device_mesh(world_size, fsdp_size):
    """Build the device mesh used for FSDP / HSDP sharding.

    Args:
        world_size: Total number of ranks taking part in the mesh.
        fsdp_size: Size of one FSDP shard group. A negative value, or a value
            greater than or equal to ``world_size``, selects a flat 1-D mesh.

    Returns:
        A 1-D mesh named ``("fsdp",)`` covering every rank, or a 2-D mesh named
        ``("ddp", "fsdp")`` of shape ``(world_size // fsdp_size, fsdp_size)``.
    """
    shards_whole_world = fsdp_size < 0 or fsdp_size >= world_size
    if shards_whole_world:
        return init_device_mesh(device_name, mesh_shape=(world_size,), mesh_dim_names=["fsdp"])

    # Hybrid layout: the outer "ddp" dimension spans the shard groups, the inner
    # "fsdp" dimension spans the ranks inside one group.
    hybrid_shape = (world_size // fsdp_size, fsdp_size)
    return init_device_mesh(device_name, mesh_shape=hybrid_shape, mesh_dim_names=["ddp", "fsdp"])
|
|
|
|
def get_sharding_strategy(device_mesh, zero3_enable=True):
    """Select the FSDP ``ShardingStrategy`` matching a device mesh.

    Args:
        device_mesh: Object exposing ``ndim``; 1 selects plain FSDP, 2 selects
            the hybrid (HSDP) variant.
        zero3_enable: When true, use ``FULL_SHARD`` / ``HYBRID_SHARD``;
            otherwise use ``SHARD_GRAD_OP`` / ``_HYBRID_SHARD_ZERO2``.

    Returns:
        The chosen ``ShardingStrategy`` member.

    Raises:
        NotImplementedError: If ``device_mesh.ndim`` is neither 1 nor 2.
    """
    from torch.distributed.fsdp import ShardingStrategy

    # Keyed by mesh dimensionality: 1-D -> FSDP strategy, 2-D -> HSDP strategy.
    if zero3_enable:
        strategy_by_ndim = {1: ShardingStrategy.FULL_SHARD, 2: ShardingStrategy.HYBRID_SHARD}
    else:
        strategy_by_ndim = {1: ShardingStrategy.SHARD_GRAD_OP, 2: ShardingStrategy._HYBRID_SHARD_ZERO2}

    mesh_ndim = device_mesh.ndim
    if mesh_ndim not in strategy_by_ndim:
        raise NotImplementedError(f"Get device mesh ndim={device_mesh.ndim}, but only support 1 or 2")
    return strategy_by_ndim[mesh_ndim]
|
|
|
|
def get_vl_model_vision_tower(vl_model_instance):
    """Return the vision tower of a VL model instance, or ``None`` if absent.

    The nested location ``vl_model_instance.model.visual`` is checked first and
    the top-level ``vl_model_instance.visual`` second.
    """
    not_present = object()
    inner_model = getattr(vl_model_instance, "model", not_present)

    # Probe the candidate holders in priority order: nested model, then the instance itself.
    for holder in (inner_model, vl_model_instance):
        if holder is not_present:
            continue
        if hasattr(holder, "visual"):
            return holder.visual
    return None
|
|
|
|
| class ActorRolloutRefWorker(Worker, DistProfilerExtension): |
| """ |
| This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy |
| or a hybrid engine based on the config.rollout |
| """ |
|
|
def __init__(self, config: DictConfig, role: str, **kwargs):
    """Set up distributed state, device meshes, role flags, profiler and batch sizes.

    Args:
        config: Worker configuration. This method reads its ``actor``,
            ``rollout``, ``ref`` and ``model`` sections and rewrites several
            batch-size fields in place (see the normalization section below).
        role: One of ``"actor"``, ``"rollout"``, ``"ref"``, ``"actor_rollout"``
            or ``"actor_rollout_ref"``.
        **kwargs: Accepted but not used inside this method.

    Raises:
        AssertionError: If ``role`` is not supported, or if the normalized PPO
            batch sizes fail the divisibility / positivity checks.
    """
    Worker.__init__(self)

    self.config = config
    import torch.distributed

    # Create the default process group only when nobody has initialized it yet.
    # Rank and world size come from the environment (defaults 0 and 1).
    if not torch.distributed.is_initialized():
        rank = int(os.environ.get("RANK", 0))
        world_size = int(os.environ.get("WORLD_SIZE", 1))
        torch.distributed.init_process_group(
            backend=f"cpu:gloo,{get_device_name()}:{get_nccl_backend()}",
            rank=rank,
            world_size=world_size,
            timeout=datetime.timedelta(seconds=self.config.get("nccl_timeout", 600)),
            init_method=os.environ.get("DIST_INIT_METHOD", None),
        )

    # Apply the NPU-related FSDP patches shipped in verl.workers.engine.fsdp.utils.
    from verl.workers.engine.fsdp.utils import apply_npu_fsdp_patches

    apply_npu_fsdp_patches()

    # Device mesh for FSDP, sized from the actor's fsdp_config.fsdp_size.
    world_size = torch.distributed.get_world_size()
    self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=self.config.actor.fsdp_config.fsdp_size)

    # Optional ("dp", "sp") mesh, created only when Ulysses sequence parallelism is > 1.
    self.ulysses_device_mesh = None
    self.ulysses_sequence_parallel_size = self.config.actor.get("ulysses_sequence_parallel_size", 1)
    dp = world_size // self.ulysses_sequence_parallel_size
    if self.ulysses_sequence_parallel_size > 1:
        self.ulysses_device_mesh = init_device_mesh(
            device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"]
        )

    # Register dispatch/collect info for the "actor" mesh. With sequence parallelism,
    # only the rank with local sp rank 0 is marked as a collector.
    if self.ulysses_device_mesh is not None:
        is_collect = self.ulysses_device_mesh["sp"].get_local_rank() == 0
        self._register_dispatch_collect_info(
            "actor", dp_rank=self.ulysses_device_mesh["dp"].get_local_rank(), is_collect=is_collect
        )
    else:
        self._register_dispatch_collect_info("actor", dp_rank=self.rank, is_collect=True)

    self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)
    # LoRA is considered active when an adapter path is configured or lora_rank > 0.
    self._lora_rank = self.config.model.get("lora_rank", 0)
    self._is_lora = self.config.model.get("lora_adapter_path") is not None or self._lora_rank > 0

    self.role = role
    assert self.role in ["actor", "rollout", "ref", "actor_rollout", "actor_rollout_ref"]

    # Capability flags derived from the role string.
    self._is_actor = self.role in ["actor", "actor_rollout", "actor_rollout_ref"]
    self._is_rollout = self.role in ["rollout", "actor_rollout", "actor_rollout_ref"]
    self._is_ref = self.role in ["ref", "actor_rollout_ref"]
    self.use_orig_params = self.config.actor.fsdp_config.get("use_orig_params", False)

    # Pick the profiler section by role priority: actor, then rollout, then ref.
    if self._is_actor:
        omega_profiler_config = config.actor.get("profiler", {})
    elif self._is_rollout:
        omega_profiler_config = config.rollout.get("profiler", {})
    elif self._is_ref:
        omega_profiler_config = config.ref.get("profiler", {})
    else:
        # Not reachable while the role assertion above is active, since every
        # accepted role sets at least one of the three flags.
        raise ValueError(
            f"Invalid role {self.role}, should be one of "
            "['actor', 'rollout', 'ref', 'actor_rollout', 'actor_rollout_ref']"
        )

    profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig)
    # A tool-specific config is built only for the listed profiler tools.
    if omega_profiler_config.get("tool", None) in ["npu", "nsys", "torch", "torch_memory"]:
        tool_config = omega_conf_to_dataclass(
            omega_profiler_config.get("tool_config", {}).get(omega_profiler_config.get("tool"))
        )
    else:
        tool_config = None
    DistProfilerExtension.__init__(
        self, DistProfiler(rank=self.rank, config=profiler_config, tool_config=tool_config)
    )

    # Offload flags: the actor may offload params and optimizer; a ref-only
    # worker may offload params only.
    self._is_offload_param = False
    self._is_offload_optimizer = False
    if self._is_actor:
        self._is_offload_param = self.config.actor.fsdp_config.get("param_offload", False)
        self._is_offload_optimizer = self.config.actor.fsdp_config.get("optimizer_offload", False)
    elif self._is_ref:
        self._is_offload_param = self.config.ref.fsdp_config.get("param_offload", False)

    # Normalize actor batch sizes in place: the mini batch is scaled by rollout.n and
    # then divided by (device_mesh.size() // ulysses_sequence_parallel_size).
    if self._is_actor:
        self.config.actor.ppo_mini_batch_size *= self.config.rollout.n
        self.config.actor.ppo_mini_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size
        assert self.config.actor.ppo_mini_batch_size > 0, (
            f"ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than 0 after "
            f"normalization"
        )
        # When ppo_micro_batch_size is set, it is divided the same way and copied
        # into ppo_micro_batch_size_per_gpu.
        if self.config.actor.ppo_micro_batch_size is not None:
            self.config.actor.ppo_micro_batch_size //= (
                self.device_mesh.size() // self.ulysses_sequence_parallel_size
            )
            self.config.actor.ppo_micro_batch_size_per_gpu = self.config.actor.ppo_micro_batch_size

        if self.config.actor.ppo_micro_batch_size_per_gpu is not None:
            assert self.config.actor.ppo_mini_batch_size % self.config.actor.ppo_micro_batch_size_per_gpu == 0, (
                f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be divisible by "
                f"ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}"
            )
            assert self.config.actor.ppo_mini_batch_size // self.config.actor.ppo_micro_batch_size_per_gpu > 0, (
                f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than "
                f"ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}"
            )

    # Normalize the rollout log-prob micro batch size with the same divisor.
    if self._is_rollout and self.config.rollout.log_prob_micro_batch_size is not None:
        self.config.rollout.log_prob_micro_batch_size //= (
            self.device_mesh.size() // self.ulysses_sequence_parallel_size
        )
        self.config.rollout.log_prob_micro_batch_size_per_gpu = self.config.rollout.log_prob_micro_batch_size
    # Normalize the reference-policy log-prob micro batch size with the same divisor.
    if self._is_ref and self.config.ref.log_prob_micro_batch_size is not None:
        self.config.ref.log_prob_micro_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size
        self.config.ref.log_prob_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size
|
|
def _init_qat_config(self):
    """Read the QAT settings from ``config.actor.qat``.

    Sets ``self.qat_config`` and ``self._qat_enabled``. If the section or one of
    the attributes read here is missing, QAT is treated as disabled and
    ``self.qat_config`` is reset to ``None``.
    """
    lookup_failures = (AttributeError, KeyError, ConfigAttributeError)
    try:
        self.qat_config = self.config.actor.qat
        enabled = self.qat_config.enable
        self._qat_enabled = enabled
        if not enabled:
            return
        logger.info(
            f"QAT enabled: mode={self.qat_config.mode}, config_path={self.qat_config.quantization_config_path}"
        )
    except lookup_failures:
        # Any missing piece of the QAT section falls back to "disabled".
        self.qat_config = None
        self._qat_enabled = False
|
|
def _restore_w4a4_input_scales(self, model, model_path):
    """Copy ``input_global_scale`` / ``input_amax`` from safetensors files into ``model``.

    Every ``model*.safetensors`` file under ``model_path`` is scanned. For each
    tensor name containing ``input_global_scale``, the owning submodule is
    resolved by attribute path and both of its buffers are filled in place.

    Args:
        model: Root module whose submodules expose ``input_global_scale`` and
            ``input_amax`` tensors.
        model_path: Directory holding the checkpoint's safetensors files.
    """
    import glob

    from safetensors import safe_open

    def _as_scalar(tensor):
        # One-element tensors are read directly; larger ones collapse to their maximum.
        if tensor.numel() == 1:
            return tensor.item()
        return tensor.max().item()

    loaded_count = 0
    checkpoint_shards = glob.glob(f"{model_path}/model*.safetensors")

    for shard_path in checkpoint_shards:
        with safe_open(shard_path, framework="pt") as shard:
            for tensor_name in shard.keys():
                if "input_global_scale" not in tensor_name:
                    continue

                owner_path = tensor_name.replace(".input_global_scale", "")

                # Walk the dotted path down from the root module to the owning submodule.
                owner = model
                for attr_name in owner_path.split("."):
                    owner = getattr(owner, attr_name)

                scale = _as_scalar(shard.get_tensor(tensor_name))
                owner.input_global_scale.fill_(scale)

                amax = _as_scalar(shard.get_tensor(f"{owner_path}.input_amax"))
                owner.input_amax.fill_(amax)
                loaded_count += 1

    if self.rank == 0:
        logger.info(f"[W4A4] Loaded {loaded_count} input scales from checkpoint")
|
|
def _build_model_optimizer(
    self,
    model_path,
    fsdp_config: FSDPEngineConfig,
    optim_config,
    override_model_config,
    use_remove_padding=False,
    use_fused_kernels=False,
    enable_gradient_checkpointing=False,
    trust_remote_code=False,
    use_liger=False,
    role="actor",
    enable_activation_offload=False,
    use_prefix_grouper=False,
    use_tiled_mlp=False,
    tiled_mlp_shards=4,
):
    """Load the HF model for ``role``, wrap it with FSDP/FSDP2 and build the optimizer.

    Besides returning the wrapped module, this method sets ``self.tokenizer``,
    ``self.processor``, ``self.generation_config``, ``self.use_orig_params`` and
    ``self._param_dtype``; under FSDP2 with an actor offload policy it also
    resets ``self._is_offload_param`` / ``self._is_offload_optimizer`` to False.

    Args:
        model_path: Local path passed to the tokenizer/processor/config/model loaders.
        fsdp_config: FSDP engine settings (dtype, mixed precision, wrap policy, ...).
        optim_config: Optimizer settings; an optimizer is only built when this is
            not None and ``role == "actor"``.
        override_model_config: Mapping merged into the HF config overrides; also
            supplies ``attn_implementation`` (default ``"flash_attention_2"``).
        use_remove_padding, use_fused_kernels, use_prefix_grouper, use_tiled_mlp,
        tiled_mlp_shards: Forwarded to ``apply_monkey_patch``.
        enable_gradient_checkpointing: Enables non-reentrant gradient checkpointing.
        trust_remote_code: Forwarded to the HF loaders.
        use_liger: Apply the Liger kernel patch to the loaded instance.
        role: ``"actor"`` or ``"ref"``.
        enable_activation_offload: Enable activation offloading on the wrapped module.

    Returns:
        Tuple ``(actor_module_fsdp, actor_optimizer, actor_lr_scheduler,
        actor_model_config)``; the optimizer and scheduler are ``None`` when no
        optimizer is built.

    Raises:
        AssertionError: If ``role`` is not ``"actor"`` or ``"ref"``.
        ValueError: If tiled MLP is requested with the ``"fsdp"`` strategy.
        NotImplementedError: If MTP is enabled, the strategy is unknown, or the
            LR scheduler type is unsupported.
    """
    from torch.distributed.fsdp import CPUOffload, MixedPrecision
    from transformers import (
        AutoConfig,
        AutoModel,
        AutoModelForCausalLM,
    )

    # NOTE(review): the AutoModelForVision2Seq bound here is overwritten by
    # get_auto_model_for_vision2seq() below; only the AutoModelForImageTextToText
    # fallback can still observe this value (possibly None) — confirm intent.
    try:
        from transformers import AutoModelForVision2Seq
    except ImportError:
        AutoModelForVision2Seq = None
    try:
        from transformers import AutoModelForImageTextToText
    except ImportError:
        AutoModelForImageTextToText = AutoModelForVision2Seq

    from verl.utils.model import get_generation_config, print_model_size, update_model_config
    from verl.utils.torch_dtypes import PrecisionType

    AutoModelForVision2Seq = get_auto_model_for_vision2seq()

    assert role in ["actor", "ref"]

    # Tiled MLP is rejected for the "fsdp" strategy; the error message points to fsdp2.
    if use_tiled_mlp and self.config.actor.strategy == "fsdp":
        raise ValueError("TiledMLP requires FSDP2. Set `actor_rollout_ref.actor.strategy=fsdp2`.")

    log_gpu_memory_usage(f"Before init {role} from HF AutoModel", logger=logger)
    local_path = model_path

    # Tokenizer and processor are stored on the worker for later use.
    self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
    self.processor = hf_processor(local_path, trust_remote_code=trust_remote_code)

    # A custom chat template goes on the processor when one exists, otherwise on the tokenizer.
    if self.config.model.get("custom_chat_template", None) is not None:
        if self.processor is not None:
            self.processor.chat_template = self.config.model.custom_chat_template
        else:
            self.tokenizer.chat_template = self.config.model.custom_chat_template

    # Model dtype: explicit fsdp_config.model_dtype wins; otherwise fp32 for an
    # actor worker and bf16 for anything else.
    torch_dtype = fsdp_config.get("model_dtype", None)
    if torch_dtype is None:
        torch_dtype = torch.float32 if self._is_actor else torch.bfloat16
    else:
        torch_dtype = PrecisionType.to_dtype(torch_dtype)

    # Attention implementation can be overridden; default is "flash_attention_2".
    attn_implementation = override_model_config.get("attn_implementation", "flash_attention_2")
    actor_model_config = AutoConfig.from_pretrained(
        local_path, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation
    )
    # With sequence parallelism on a model that has a vision_config, the vision
    # attention is forced to "eager".
    if self.ulysses_sequence_parallel_size > 1 and hasattr(actor_model_config, "vision_config"):
        actor_model_config.vision_config._attn_implementation = "eager"

    # For qwen2_5_vl with flash_attention_3, the vision attention is set to
    # "flash_attention_2" instead.
    if (
        getattr(actor_model_config, "model_type", None) == "qwen2_5_vl"
        and attn_implementation == "flash_attention_3"
        and hasattr(actor_model_config, "vision_config")
    ):
        actor_model_config.vision_config._attn_implementation = "flash_attention_2"

    # kimi_vl: force the text config's topk_method to "greedy".
    if getattr(actor_model_config, "model_type", None) == "kimi_vl":
        actor_model_config.text_config.topk_method = "greedy"

    self.generation_config = get_generation_config(local_path, trust_remote_code=trust_remote_code)

    # Special-token ids from the tokenizer are the baseline overrides.
    override_config_kwargs = {
        "bos_token_id": self.tokenizer.bos_token_id,
        "eos_token_id": self.tokenizer.eos_token_id,
        "pad_token_id": self.tokenizer.pad_token_id,
    }

    # MTP is rejected; when the config has num_nextn_predict_layers it is zeroed.
    if self.config.model.get("mtp", {}).get("enable", False):
        raise NotImplementedError("Right now, MTP is not supported in FSDP")
    else:
        if hasattr(actor_model_config, "num_nextn_predict_layers"):
            actor_model_config.num_nextn_predict_layers = 0

    # User overrides take precedence over the tokenizer-derived defaults.
    override_config_kwargs.update(override_model_config)
    update_model_config(actor_model_config, override_config_kwargs=override_config_kwargs)
    if self.rank == 0:
        print(f"Model config after override: {actor_model_config}")

    # Meta-tensor init is requested only when word embeddings are not tied.
    init_context = get_init_weight_context_manager(
        use_meta_tensor=not actor_model_config.tie_word_embeddings, mesh=self.device_mesh
    )

    with init_context(), warnings.catch_warnings():
        warnings.simplefilter("ignore")
        # Remote code is assumed when the config's auto_map mentions the first architecture.
        has_remote_code = hasattr(actor_model_config, "auto_map") and any(
            actor_model_config.architectures[0] in val for val in actor_model_config.auto_map.values()
        )
        if has_remote_code:
            # Use the auto class under which the architecture is registered in auto_map.
            auto_class = next(
                k for k, v in actor_model_config.auto_map.items() if actor_model_config.architectures[0] in v
            )
            match auto_class:
                case "AutoModelForVision2Seq":
                    actor_module_class = AutoModelForVision2Seq
                case "AutoModelForCausalLM":
                    actor_module_class = AutoModelForCausalLM
                case "AutoModelForImageTextToText":
                    actor_module_class = AutoModelForImageTextToText
                case _:
                    actor_module_class = AutoModel
        else:
            # Otherwise pick the first auto class whose model mapping knows this config type.
            if type(actor_model_config) in AutoModelForVision2Seq._model_mapping.keys():
                actor_module_class = AutoModelForVision2Seq
            elif type(actor_model_config) in AutoModelForCausalLM._model_mapping.keys():
                actor_module_class = AutoModelForCausalLM
            elif type(actor_model_config) in AutoModelForImageTextToText._model_mapping.keys():
                actor_module_class = AutoModelForImageTextToText
            else:
                actor_module_class = AutoModel

        actor_module = actor_module_class.from_pretrained(
            pretrained_model_name_or_path=local_path,
            torch_dtype=torch_dtype,
            config=actor_model_config,
            trust_remote_code=trust_remote_code,
            attn_implementation=attn_implementation,
        )

        # Optional Liger kernel patch on the loaded instance.
        if use_liger:
            from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance

            _apply_liger_kernel_to_instance(model=actor_module)

        fused_kernel_options = self.config.model.get("fused_kernel_options", None)
        fused_kernels_backend = (
            fused_kernel_options.get("impl_backend", None) if fused_kernel_options is not None else None
        )

        apply_monkey_patch(
            model=actor_module,
            use_remove_padding=use_remove_padding,
            ulysses_sp_size=self.ulysses_sequence_parallel_size,
            use_fused_kernels=use_fused_kernels,
            fused_kernels_backend=fused_kernels_backend,
            use_prefix_grouper=use_prefix_grouper,
            use_tiled_mlp=use_tiled_mlp,
            tiled_mlp_shards=tiled_mlp_shards,
        )

        # Cast the whole module to torch_dtype after patching.
        actor_module.to(torch_dtype)

        if enable_gradient_checkpointing:
            actor_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

    # NOTE(review): the LoRA section is assumed to run after the init context has
    # exited; the source's indentation was lost, so confirm this nesting.
    if self._is_lora:
        print("Applying LoRA to actor module")
        actor_module.enable_input_require_grads()

        lora_adapter_path = self.config.model.get("lora_adapter_path")
        if lora_adapter_path is not None:
            from peft import PeftModel

            print(f"Loading pre-trained LoRA adapter to {role} from: {lora_adapter_path}")

            # Fetch the adapter to a local path before loading it.
            local_adapter_path = copy_to_local(lora_adapter_path, use_shm=self.config.model.get("use_shm", False))

            actor_module = PeftModel.from_pretrained(actor_module, local_adapter_path, is_trainable=True)
            peft_config = actor_module.peft_config["default"]
            # A task_type stored as a plain string is replaced by the TaskType enum member.
            if isinstance(peft_config.task_type, str):
                peft_config.task_type = TaskType.CAUSAL_LM

        else:
            # No pre-trained adapter: create a fresh LoRA adapter from the model config.
            lora_config = {
                "task_type": TaskType.CAUSAL_LM,
                "r": self.config.model.lora_rank,
                "lora_alpha": self.config.model.lora_alpha,
                "target_modules": convert_to_regular_types(self.config.model.target_modules),
                "exclude_modules": convert_to_regular_types(self.config.model.exclude_modules),
                "bias": "none",
            }
            actor_module = get_peft_model(actor_module, LoraConfig(**lora_config))

    # use_orig_params comes from fsdp_config but is forced to True when a vision
    # tower is found and frozen.
    self.use_orig_params = fsdp_config.get("use_orig_params", False)
    if self.config.actor.get("freeze_vision_tower", False):
        vision_tower = get_vl_model_vision_tower(actor_module)
        if vision_tower is not None:
            vision_tower.requires_grad_(False)
            self.use_orig_params = True
            if self.rank == 0:
                print("[actor model] Vision tower is set to not trainable.")
        else:
            if self.rank == 0:
                print("[actor model] No vision tower found.")

    # QAT is applied to the actor only; W4A4 mode additionally restores input
    # scales from the checkpoint at config.model.path.
    if role == "actor" and self._qat_enabled:
        actor_module = apply_qat(actor_module, self.qat_config)
        enable_qat_fuse(actor_module)
        if self.qat_config.mode == "w4a4":
            self._restore_w4a4_input_scales(actor_module, self.config.model.path)

    # Synchronize all ranks before continuing with sharding.
    torch.distributed.barrier()

    if self.rank == 0:
        print_model_size(actor_module)

    log_gpu_memory_usage(f"After init {role} from HF AutoModel", logger=logger)

    # Mixed precision: explicit per-field settings when configured (defaults
    # bf16 / fp32 / fp32), otherwise fsdp_config.dtype for params and fp32 for
    # reduction and buffers.
    mixed_precision_config = fsdp_config.get("mixed_precision", None)
    if mixed_precision_config is not None:
        param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16"))
        reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get("reduce_dtype", "fp32"))
        buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get("buffer_dtype", "fp32"))
    else:
        param_dtype = PrecisionType.to_dtype(fsdp_config.dtype)
        reduce_dtype = torch.float32
        buffer_dtype = torch.float32

    mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)

    # Remembered on the worker; rollout_mode passes it to the QAT quantizer.
    self._param_dtype = param_dtype

    auto_wrap_policy = get_fsdp_wrap_policy(
        module=actor_module,
        config=fsdp_config.get("wrap_policy", None),
        is_lora=self._is_lora,
    )

    if self.rank == 0:
        print(f"wrap_policy: {auto_wrap_policy}")

    fsdp_mesh = self.device_mesh
    # reshard_after_forward doubles as the "zero3" switch for the FSDP1 strategy choice.
    fsdp_enable_zero3 = fsdp_config.reshard_after_forward
    sharding_strategy = get_sharding_strategy(fsdp_mesh, fsdp_enable_zero3)

    # FSDP1 CPU offload: none for the actor, parameter offload for any other role.
    cpu_offload = None if role == "actor" else CPUOffload(offload_params=True)
    fsdp_strategy = self.config.actor.strategy
    if fsdp_strategy == "fsdp":
        actor_module_fsdp = FSDP(
            actor_module,
            cpu_offload=cpu_offload,
            param_init_fn=init_fn,
            auto_wrap_policy=auto_wrap_policy,
            device_id=get_device_id(),
            sharding_strategy=sharding_strategy,
            mixed_precision=mixed_precision,
            sync_module_states=True,
            device_mesh=self.device_mesh,
            use_orig_params=self.use_orig_params,
            forward_prefetch=fsdp_config.get("forward_prefetch", False),
        )
    elif fsdp_strategy == "fsdp2":
        assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"
        mp_policy = MixedPrecisionPolicy(
            param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True
        )
        if role == "actor" and fsdp_config.offload_policy:
            # With an FSDP2 offload policy on the actor, the worker-level
            # offload flags are switched off.
            cpu_offload = CPUOffloadPolicy(pin_memory=True)
            self._is_offload_param = False
            self._is_offload_optimizer = False
        else:
            cpu_offload = None if role == "actor" else CPUOffloadPolicy(pin_memory=True)

        fsdp_kwargs = {
            "mesh": fsdp_mesh,
            "mp_policy": mp_policy,
            "offload_policy": cpu_offload,
            "reshard_after_forward": fsdp_config.reshard_after_forward,
            "shard_placement_fn": get_shard_placement_fn(fsdp_size=self.device_mesh.shape[-1]),
        }
        # Capture the full state dict before sharding, then load it back into
        # the sharded module.
        full_state = actor_module.state_dict()
        apply_fsdp2(actor_module, fsdp_kwargs, fsdp_config)
        fsdp2_load_full_state_dict(actor_module, full_state, fsdp_mesh, cpu_offload)
        actor_module_fsdp = actor_module
    else:
        raise NotImplementedError(f"not implement {fsdp_strategy}")

    if enable_activation_offload:
        enable_activation_offloading(actor_module_fsdp, fsdp_strategy, enable_gradient_checkpointing)

    log_gpu_memory_usage(f"After {role} FSDP init", logger=logger)

    # Optimizer and LR scheduler exist only for the actor with an optimizer config.
    if role == "actor" and optim_config is not None:
        from verl.utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup

        actor_optimizer = build_optimizer(actor_module_fsdp.parameters(), optim_config)

        total_steps = optim_config.get("total_training_steps", 0)
        num_warmup_steps = int(optim_config.get("lr_warmup_steps", -1))
        lr_scheduler_type = optim_config.get("lr_scheduler_type", "constant")
        min_lr_ratio = optim_config.get("min_lr_ratio", 0.0)
        num_cycles = optim_config.get("num_cycles", 0.5)
        # A negative lr_warmup_steps means "derive it from lr_warmup_steps_ratio".
        if num_warmup_steps < 0:
            num_warmup_steps_ratio = optim_config.get("lr_warmup_steps_ratio", 0.0)
            num_warmup_steps = int(num_warmup_steps_ratio * total_steps)

        if self.rank == 0:
            print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}")

        if lr_scheduler_type == "constant":
            actor_lr_scheduler = get_constant_schedule_with_warmup(
                optimizer=actor_optimizer, num_warmup_steps=num_warmup_steps
            )
        elif lr_scheduler_type == "cosine":
            actor_lr_scheduler = get_cosine_schedule_with_warmup(
                optimizer=actor_optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=total_steps,
                min_lr_ratio=min_lr_ratio,
                num_cycles=num_cycles,
            )
        else:
            raise NotImplementedError(f"LR scheduler type {lr_scheduler_type} is not supported")

        log_gpu_memory_usage(f"After {role} optimizer init", logger=logger)
    else:
        actor_optimizer = None
        actor_lr_scheduler = None

    return actor_module_fsdp, actor_optimizer, actor_lr_scheduler, actor_model_config
|
|
def _build_rollout(self, trust_remote_code=False):
    """Create the rollout engine and the device mesh it runs on.

    Sets ``self.model_config``, ``self.rollout_device_mesh``, ``self.rollout``,
    ``self.base_sync_done`` and ``self.layered_summon``, registers the
    ``"rollout"`` dispatch/collect info, and (for FSDP1 modules) selects the
    state-dict type of ``self.actor_module_fsdp``.

    Args:
        trust_remote_code: Accepted for interface compatibility; not read here.
    """
    from torch.distributed.device_mesh import init_device_mesh

    # Typed views of the rollout and model config sections.
    rollout_config: RolloutConfig = omega_conf_to_dataclass(self.config.rollout)
    model_config: HFModelConfig = omega_conf_to_dataclass(self.config.model, dataclass_type=HFModelConfig)
    self.model_config = model_config

    # Inference parallel layout: tensor and data parallel sizes are folded into
    # one "infer_tp" dimension, pipeline parallel forms "infer_pp".
    raw_rollout_cfg = self.config.rollout
    infer_tp = raw_rollout_cfg.tensor_model_parallel_size * raw_rollout_cfg.data_parallel_size
    infer_pp = raw_rollout_cfg.pipeline_model_parallel_size
    infer_world_size = infer_tp * infer_pp
    dp = self.world_size // infer_world_size
    assert self.world_size % infer_world_size == 0, (
        f"rollout world_size: {self.world_size} is not divisible by infer_world_size: {infer_world_size}"
    )
    rollout_device_mesh = init_device_mesh(
        device_name, mesh_shape=(dp, infer_tp, infer_pp), mesh_dim_names=["dp", "infer_tp", "infer_pp"]
    )
    rollout_name = raw_rollout_cfg.name

    self.rollout_device_mesh = rollout_device_mesh

    # The "hf" rollout treats every rank as its own collecting dp rank; other
    # rollouts collect only from ranks with local tp rank 0 and local pp rank 0.
    if rollout_name == "hf":
        collect_dp_rank = self.rank
        is_collect = True
    else:
        is_collect = (
            rollout_device_mesh["infer_tp"].get_local_rank() == 0
            and rollout_device_mesh["infer_pp"].get_local_rank() == 0
        )
        collect_dp_rank = rollout_device_mesh["dp"].get_local_rank()
    self._register_dispatch_collect_info("rollout", dp_rank=collect_dp_rank, is_collect=is_collect)

    # Instantiate the rollout engine selected by name and mode.
    log_gpu_memory_usage(f"Before building {self.config.rollout.name} rollout", logger=logger)
    rollout_cls = get_rollout_class(rollout_config.name, rollout_config.mode)
    self.rollout = rollout_cls(config=rollout_config, model_config=model_config, device_mesh=rollout_device_mesh)
    log_gpu_memory_usage(f"After building {self.config.rollout.name} rollout", logger=logger)

    # FSDP1 only: full state dict on a single rank, sharded state dict otherwise.
    single_rank = torch.distributed.get_world_size() == 1
    if single_rank and fsdp_version(self.actor_module_fsdp) == 1:
        state_dict_type, state_dict_config = StateDictType.FULL_STATE_DICT, FullStateDictConfig()
    elif fsdp_version(self.actor_module_fsdp) == 1:
        state_dict_type, state_dict_config = StateDictType.SHARDED_STATE_DICT, ShardedStateDictConfig()
    else:
        state_dict_type = None
    if state_dict_type is not None:
        FSDP.set_state_dict_type(
            self.actor_module_fsdp,
            state_dict_type=state_dict_type,
            state_dict_config=state_dict_config,
        )

    # With a load_format containing "dummy", base weights are considered not yet synced.
    self.base_sync_done: bool = "dummy" not in self.config.rollout.load_format
    self.layered_summon = self.config.rollout.get("layered_summon", False)
| self.layered_summon = self.config.rollout.get("layered_summon", False) |
|
|
| |
| |
| |
| |
|
|
async def rollout_mode(self):
    """Context switch hybridengine to rollout mode.

    Collects the current actor weights (LoRA-aware), optionally quantizes them
    for QAT, pushes them to ``self.rollout`` via ``update_weights`` and resumes
    the rollout engine's weights / KV cache when ``free_cache_engine`` is set.
    On exit ``self.base_sync_done`` is True.

    Fixes relative to the previous version:
      * ``device`` is resolved once up front. It used to be assigned only in the
        non-LoRA / first-sync branch although the base-weight re-sync generator
        also reads it, which could raise ``NameError`` for DTensor base params.
      * The "re-sync base weights" condition is evaluated once, so collecting
        and pushing the base weights can never disagree.
    """
    aggressive_empty_cache(force_sync=True)

    log_gpu_memory_usage("Before load_fsdp_model_to_gpu", logger=logger)
    if self._is_offload_param:
        load_fsdp_model_to_gpu(self.actor_module_fsdp)
    log_gpu_memory_usage("After load_fsdp_model_to_gpu", logger=logger)

    # Target device for materializing DTensor parameters; shared by both
    # generator expressions below.
    device = get_device_id()

    # Unwrap FSDP1 if present so the PEFT wrapper (if any) is visible.
    peft_config = None
    peft_model = getattr(self.actor_module_fsdp, "_fsdp_wrapped_module", self.actor_module_fsdp)
    if hasattr(peft_model, "peft_config"):
        # LoRA: collect adapter params (or merged params on the first sync).
        peft_config = peft_model.peft_config.get("default", None)
        params = collect_lora_params(
            module=self.actor_module_fsdp,
            layered_summon=self.config.rollout.get("layered_summon", False),
            base_sync_done=self.base_sync_done,
        )
        if not self.base_sync_done:
            params = {replace_lora_wrapper(k, peft_config): v for k, v in params.items()}
    else:
        params = self.actor_module_fsdp.state_dict()

    params = convert_weight_keys(
        params, getattr(self.actor_module_fsdp, "_fsdp_wrapped_module", self.actor_module_fsdp)
    )

    # With LoRA, sleep_level 2 and free_cache_engine, the base model weights are
    # pushed again before the adapter weights. Decide this once.
    resync_base_weights = bool(
        peft_config is not None
        and getattr(self.rollout, "sleep_level", None) == 2
        and self.config.rollout.free_cache_engine
    )
    if resync_base_weights:
        base_model_params = collect_lora_params(
            module=self.actor_module_fsdp,
            layered_summon=self.layered_summon,
            base_sync_done=False,
        )
        base_model_params = {replace_lora_wrapper(k, peft_config): v for k, v in base_model_params.items()}
        base_model_params = convert_weight_keys(
            base_model_params, getattr(self.actor_module_fsdp, "_fsdp_wrapped_module", self.actor_module_fsdp)
        )

    log_gpu_memory_usage("Before offload_fsdp_model_to_cpu", logger=logger)
    if self._is_offload_param:
        offload_fsdp_model_to_cpu(self.actor_module_fsdp)
    log_gpu_memory_usage("After offload_fsdp_model_to_cpu", logger=logger)

    set_expandable_segments(False)

    if peft_config is not None and self.base_sync_done:
        # Adapter-only sync: pass the collected params through unchanged.
        per_tensor_param = params.items() if isinstance(params, dict) else params
    else:
        # Lazily gather DTensor shards into full tensors on the local device.
        per_tensor_param = (
            (name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param)
            for name, param in params.items()
        )

    # QAT: quantize the weight stream before it is handed to the rollout engine.
    if self._qat_enabled:
        from verl.utils.qat.quantizer import QATQuantizer

        quantizer = QATQuantizer(
            mode=self.qat_config.mode,
            group_size=self.qat_config.group_size,
            ignore_patterns=self.qat_config.ignore_patterns,
            device=torch.device(get_device_id()),
            param_dtype=self._param_dtype,
        )
        per_tensor_param = quantizer.quantize_with_fusion(
            per_tensor_param,
            target_device=torch.device("cpu"),
        )
        aggressive_empty_cache(force_sync=True)

    if self.config.rollout.free_cache_engine:
        await self.rollout.resume(tags=["weights"])
    log_gpu_memory_usage("After resume weights", logger=logger)

    if resync_base_weights:
        per_tensor_base_params = (
            (name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param)
            for name, param in base_model_params.items()
        )
        await self.rollout.update_weights(per_tensor_base_params, base_sync_done=False)
        del base_model_params, per_tensor_base_params

    await self.rollout.update_weights(per_tensor_param, peft_config=peft_config, base_sync_done=self.base_sync_done)
    log_gpu_memory_usage("After update_weights", logger=logger)
    # Drop the weight references before emptying the cache.
    del params, per_tensor_param
    aggressive_empty_cache(force_sync=True)
    if self.config.rollout.free_cache_engine:
        await self.rollout.resume(tags=["kv_cache"])
    log_gpu_memory_usage("After resume kv_cache", logger=logger)

    self.base_sync_done = True
    set_expandable_segments(True)
|
|
| @register(dispatch_mode=Dispatch.ONE_TO_ALL) |
def init_model(self):
    """Build every model component this worker owns, driven by its role flags.

    * actor and/or rollout: one FSDP policy module is built through
      ``_build_model_optimizer`` (an optimizer config is passed only for the
      actor role), then optionally offloaded to CPU.
    * actor: wrapped in ``DataParallelPPOActor``; a FLOPs counter and a
      checkpoint manager are attached.
    * rollout: the rollout engine is built through ``_build_rollout``.
    * ref: a separate module is built without an optimizer config and wrapped
      in ``DataParallelPPOActor`` as ``self.ref_policy``.
    * standalone rollout (no actor): a checkpoint manager restricted to
      loading model contents is attached.
    """
    from verl.workers.actor import DataParallelPPOActor

    model_cfg = self.config.model

    # External libraries are imported before any model gets constructed.
    import_external_libs(model_cfg.get("external_lib", None))

    # QAT configuration is initialised ahead of model construction.
    self._init_qat_config()

    override_model_config = OmegaConf.to_container(OmegaConf.create(model_cfg.get("override_config", {})))
    use_remove_padding = model_cfg.get("use_remove_padding", False)
    use_shm = model_cfg.get("use_shm", False)
    use_fused_kernels = model_cfg.get("use_fused_kernels", False)
    trust_remote_code = model_cfg.get("trust_remote_code", False)
    use_liger = model_cfg.get("use_liger", False)

    if self._is_actor or self._is_rollout:
        # Both roles need the policy module; only the actor brings its own
        # optimizer and FSDP settings, a pure rollout worker uses defaults.
        if self._is_actor:
            optim_config = self.config.actor.optim
            fsdp_config = omega_conf_to_dataclass(self.config.actor.fsdp_config)
        else:
            optim_config = None
            fsdp_config = FSDPEngineConfig()

        local_path = copy_to_local(model_cfg.path, use_shm=use_shm)

        # TiledMLP settings for the policy module.
        tiled_cfg = model_cfg.get("tiled_mlp", {})

        built = self._build_model_optimizer(
            model_path=local_path,
            fsdp_config=fsdp_config,
            optim_config=optim_config,
            override_model_config=override_model_config,
            use_remove_padding=use_remove_padding,
            use_fused_kernels=use_fused_kernels,
            enable_gradient_checkpointing=model_cfg.get("enable_gradient_checkpointing", False),
            trust_remote_code=trust_remote_code,
            use_liger=use_liger,
            role="actor",
            enable_activation_offload=model_cfg.get("enable_activation_offload", False),
            use_prefix_grouper=self.config.actor.get("use_prefix_grouper", False),
            use_tiled_mlp=tiled_cfg.get("enabled", False),
            tiled_mlp_shards=tiled_cfg.get("num_shards", 4),
        )
        (
            self.actor_module_fsdp,
            self.actor_optimizer,
            self.actor_lr_scheduler,
            self.actor_model_config,
        ) = built

        # For FSDP1 keep a handle on the module wrapped by the FSDP root.
        if fsdp_version(self.actor_module_fsdp) == 1:
            self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module

        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.actor_module_fsdp)
            log_gpu_memory_usage("After offload actor model during init", logger=logger)

        if self._is_offload_optimizer:
            offload_fsdp_optimizer(optimizer=self.actor_optimizer)
            log_gpu_memory_usage("After offload actor optimizer during init", logger=logger)

    if self._is_actor:
        self.actor = DataParallelPPOActor(
            config=omega_conf_to_dataclass(self.config.actor),
            actor_module=self.actor_module_fsdp,
            actor_optimizer=self.actor_optimizer,
        )

    if self._is_rollout:
        self._build_rollout(trust_remote_code=trust_remote_code)

    if self._is_ref:
        ref_cfg = self.config.ref

        # The reference model defaults to the policy path unless `ref.model.path` is given.
        ref_model_path = model_cfg.path
        ref_model = ref_cfg.get("model", None)
        if ref_model is not None:
            ref_model_path = ref_model.get("path", model_cfg.path)

        if self.rank == 0:
            print("reference model:", ref_model_path)
        local_path = copy_to_local(ref_model_path, use_shm=use_shm)
        use_prefix_grouper = hasattr(self.config, "actor") and self.config.actor.get("use_prefix_grouper", False)

        # TiledMLP settings for the reference module: `ref.tiled_mlp` wins,
        # otherwise fall back to the policy model's settings.
        ref_tiled_cfg = ref_cfg.get("tiled_mlp", None)
        if ref_tiled_cfg is None:
            ref_tiled_cfg = model_cfg.get("tiled_mlp", {})

        ref_build = self._build_model_optimizer(
            model_path=local_path,
            fsdp_config=omega_conf_to_dataclass(ref_cfg.fsdp_config),
            optim_config=None,
            override_model_config=override_model_config,
            use_remove_padding=use_remove_padding,
            use_fused_kernels=use_fused_kernels,
            trust_remote_code=trust_remote_code,
            use_liger=use_liger,
            role="ref",
            use_prefix_grouper=use_prefix_grouper,
            use_tiled_mlp=ref_tiled_cfg.get("enabled", False),
            tiled_mlp_shards=ref_tiled_cfg.get("num_shards", 4),
        )
        # Only the module (first element) is kept for the reference policy.
        self.ref_module_fsdp = ref_build[0]

        # The ref config is put in struct mode; open_dict temporarily allows
        # adding the keys below.
        OmegaConf.set_struct(ref_cfg, True)
        with open_dict(ref_cfg):
            ref_cfg.use_remove_padding = use_remove_padding
            ref_cfg.use_fused_kernels = use_fused_kernels
            if use_prefix_grouper:
                ref_cfg.use_prefix_grouper = use_prefix_grouper
        self.ref_policy = DataParallelPPOActor(config=ref_cfg, actor_module=self.ref_module_fsdp)

    if self._is_actor:
        self.flops_counter = FlopsCounter(self.actor_model_config)
        self.checkpoint_manager = FSDPCheckpointManager(
            model=self.actor_module_fsdp,
            optimizer=self.actor.actor_optimizer,
            lr_scheduler=self.actor_lr_scheduler,
            processing_class=self.processor if self.processor is not None else self.tokenizer,
            checkpoint_config=self.config.actor.checkpoint,
            trust_remote_code=trust_remote_code,
        )

    if self._is_rollout and not self._is_actor:
        # Standalone rollout worker: the manager is configured to load model
        # contents only and to save nothing.
        load_only_contents = OmegaConf.create({"load_contents": ["model"], "save_contents": []})
        self.checkpoint_manager = FSDPCheckpointManager(
            model=self.actor_module_fsdp,
            optimizer=None,
            lr_scheduler=None,
            processing_class=self.processor if self.processor is not None else self.tokenizer,
            checkpoint_config=load_only_contents,
        )

    # Release cached device memory once initialisation is complete.
    aggressive_empty_cache(force_sync=True)
|
|
| @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) |
| @DistProfiler.annotate(color="red", role="actor_update") |
def update_actor(self, data: DataProto):
    """Run one policy update on ``data`` and return the collected metrics.

    Parameters and optimizer state are moved to the device first when the
    corresponding offload flags are set, and moved back afterwards. The
    returned ``DataProto`` carries only ``meta_info["metrics"]``.
    """
    assert self._is_actor
    if self._is_offload_param:
        load_fsdp_model_to_gpu(self.actor_module_fsdp)
    if self._is_offload_optimizer:
        load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=get_device_id())

    with self.ulysses_sharding_manager:
        data = data.to("cpu")
        data.meta_info.setdefault("pad_token_id", self.tokenizer.pad_token_id)

        # Time the policy update itself; the elapsed time feeds the MFU estimate.
        with Timer(name="update_policy", logger=None) as timer:
            metrics = self.actor.update_policy(data=data)
        elapsed = timer.last

        estimated_flops, promised_flops = self.flops_counter.estimate_flops(
            data.meta_info["global_token_num"],
            elapsed,
            images_seqlens=data.meta_info.get("images_seqlens", None),
        )
        metrics["perf/mfu/actor"] = (
            estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size
        )

        # Memory statistics are reported in GiB.
        gib = 1024**3
        metrics["perf/max_memory_allocated_gb"] = get_torch_device().max_memory_allocated() / gib
        metrics["perf/max_memory_reserved_gb"] = get_torch_device().max_memory_reserved() / gib
        metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / gib

        # Record the learning rate that was in effect, then advance the schedule.
        lr = self.actor_lr_scheduler.get_last_lr()[0]
        if torch.is_tensor(lr):
            lr = lr.item()
        metrics["actor/lr"] = lr
        self.actor_lr_scheduler.step()

        output = DataProto(meta_info={"metrics": metrics}).to("cpu")

    if self._is_offload_param:
        offload_fsdp_model_to_cpu(self.actor_module_fsdp)
        log_gpu_memory_usage("After offload actor model during update_actor", logger=logger)
    if self._is_offload_optimizer:
        offload_fsdp_optimizer(optimizer=self.actor_optimizer)
        log_gpu_memory_usage("After offload actor optimizer during update_actor", logger=logger)

    return output
|
|
| @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="rollout")) |
| @DistProfiler.annotate(color="red", role="rollout_generate") |
def generate_sequences(self, prompts: DataProto):
    """Generate sequences for ``prompts`` with the rollout engine.

    When this worker is also the actor, it switches to rollout mode before
    generation and back to trainer mode afterwards. Generation timing
    statistics are attached to the output's ``meta_info["timing"]``.
    """
    assert self._is_rollout
    prompts = prompts.to(get_device_id())

    # Special token ids come from the generation config when one exists,
    # otherwise from the tokenizer.
    gen_cfg = self.generation_config
    if gen_cfg is not None:
        eos_token_id = gen_cfg.eos_token_id
        pad_token_id = gen_cfg.pad_token_id
    else:
        eos_token_id = self.tokenizer.eos_token_id
        pad_token_id = self.tokenizer.pad_token_id
    prompts.meta_info.update({"eos_token_id": eos_token_id, "pad_token_id": pad_token_id})

    timing_generate = {}
    if self._is_actor:
        loop = get_event_loop()
        loop.run_until_complete(self.rollout_mode())
        log_gpu_memory_usage("After switch to rollout mode", logger=logger)

    with simple_timer("generate_sequences", timing_generate):
        output = self.rollout.generate_sequences(prompts=prompts)

    if self._is_actor:
        loop.run_until_complete(self.trainer_mode())
        log_gpu_memory_usage("After switch to trainer mode", logger=logger)

    # Derive topk-ratio/min/max from the raw timing before it is reduced,
    # then publish the reduced timing together with those statistics.
    topk_ratio, t_min, t_max = topk_reduce_ratio_min_max(timing_generate["generate_sequences"])
    timing_generate = reduce_timing(timing_generate)
    timing_generate["generation_timing/max"] = t_max
    timing_generate["generation_timing/min"] = t_min
    timing_generate["generation_timing/topk_ratio"] = topk_ratio
    output.meta_info["timing"] = timing_generate
    output = output.to("cpu")

    # Drop cached device memory before returning.
    get_torch_device().empty_cache()
    return output
|
|
| @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) |
| @DistProfiler.annotate(color="blue", role="actor_compute_log_prob") |
def compute_log_prob(self, data: DataProto):
    """Compute log-probabilities of ``data`` with the actor module.

    If ``data.meta_info`` carries ``is_lora=True`` the LoRA adapter is
    disabled for the forward pass and the result is returned under
    ``ref_log_prob`` (settings read from ``config.ref``, no entropy).
    Otherwise the result is returned under ``old_log_probs`` together with
    ``entropys`` (settings read from ``config.rollout``).
    """
    assert self._is_actor
    if self._is_offload_param:
        load_fsdp_model_to_gpu(self.actor_module_fsdp)

    from contextlib import nullcontext

    is_lora = data.meta_info.pop("is_lora", False)
    if is_lora:
        adapter_ctx = self.actor.actor_module.disable_adapter()
        config_source = self.config.ref
    else:
        adapter_ctx = nullcontext()
        config_source = self.config.rollout

    meta = data.meta_info
    meta["micro_batch_size"] = config_source.log_prob_micro_batch_size_per_gpu
    meta["max_token_len"] = config_source.log_prob_max_token_len_per_gpu
    meta["use_dynamic_bsz"] = config_source.log_prob_use_dynamic_bsz
    meta["temperature"] = self.config.rollout.temperature
    meta.setdefault("pad_token_id", self.tokenizer.pad_token_id)

    # Entropy is requested only for the regular (non-LoRA-reference) pass.
    calculate_entropy = not is_lora
    with self.ulysses_sharding_manager:
        with adapter_ctx:
            outputs = self.actor.compute_log_prob(data=data, calculate_entropy=calculate_entropy)

        log_prob_key = "ref_log_prob" if is_lora else "old_log_probs"
        tensors = {log_prob_key: outputs["log_probs"]}
        if calculate_entropy:
            tensors["entropys"] = outputs["entropys"]
        if "sum_pi_squared" in outputs:
            tensors["sum_pi_squared"] = outputs["sum_pi_squared"]
        output = DataProto.from_dict(
            tensors=tensors,
            meta_info={"temperature": self.config.rollout.temperature},
        )

    output = output.to("cpu")

    # With FSDP1 on more than one rank, explicitly reshard the root module's
    # flat parameter after the forward pass.
    if self.world_size > 1 and fsdp_version(self.actor.actor_module) == 1:
        self.actor.actor_module._handle.reshard(True)

    if self._is_offload_param:
        offload_fsdp_model_to_cpu(self.actor_module_fsdp)
        log_gpu_memory_usage("After offload actor model during compute_log_prob", logger=logger)

    return output
|
|
| @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor")) |
| @DistProfiler.annotate(color="olive", role="ref_compute_log_prob") |
def compute_ref_log_prob(self, data: DataProto):
    """Compute reference log-probabilities for ``data``.

    For LoRA workers the request is delegated to ``compute_log_prob`` with
    ``is_lora=True`` set in ``meta_info``. Otherwise the dedicated reference
    policy is used and the result is returned under ``ref_log_prob``.
    """
    if self._is_lora:
        # Flag the batch so that compute_log_prob runs with the adapter disabled.
        data.meta_info["is_lora"] = True
        return self.compute_log_prob(data)
    assert self._is_ref

    ref_cfg = self.config.ref
    meta = data.meta_info
    meta["micro_batch_size"] = ref_cfg.log_prob_micro_batch_size_per_gpu
    meta["temperature"] = self.config.rollout.temperature
    meta["max_token_len"] = ref_cfg.log_prob_max_token_len_per_gpu
    meta["use_dynamic_bsz"] = ref_cfg.log_prob_use_dynamic_bsz
    meta.setdefault("pad_token_id", self.tokenizer.pad_token_id)

    with self.ulysses_sharding_manager:
        data = data.to("cpu")
        outputs = self.ref_policy.compute_log_prob(data=data, calculate_entropy=False)
        output = DataProto.from_dict(tensors={"ref_log_prob": outputs["log_probs"]})

    output = output.to("cpu")

    # On more than one rank, reshard the reference module after the forward
    # pass; FSDP1 and FSDP2 expose different calls for this.
    if self.world_size > 1:
        ref_module = self.ref_policy.actor_module
        if fsdp_version(ref_module) == 1:
            ref_module._handle.reshard(True)
        elif fsdp_version(ref_module) == 2:
            ref_module.reshard()

    return output
|
|
| @register(dispatch_mode=Dispatch.ONE_TO_ALL) |
def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
    """Save the actor checkpoint and, for LoRA models, a standalone adapter.

    Args:
        local_path: Directory the checkpoint is written to. The LoRA adapter,
            when applicable, goes to ``<local_path>/lora_adapter``.
        hdfs_path: Forwarded to the checkpoint manager.
        global_step: Forwarded to the checkpoint manager.
        max_ckpt_to_keep: Forwarded to the checkpoint manager.

    The LoRA export is best-effort: a failure is logged rather than raised.
    The "Saved LoRA adapter" message is emitted only when the export actually
    completed, so a failed or skipped export is no longer reported as saved.
    """
    from verl.utils.logger import log_with_rank

    # only the actor role supports saving checkpoints
    assert self._is_actor

    if self._is_offload_param:
        load_fsdp_model_to_gpu(self.actor_module_fsdp)

    self.checkpoint_manager.save_checkpoint(
        local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep
    )
    dist.barrier()

    # `actor_module` only exists for FSDP1; fall back to the FSDP module otherwise.
    peft_model = getattr(self, "actor_module", self.actor_module_fsdp)
    if self._is_lora and hasattr(peft_model, "peft_config"):
        lora_save_path = os.path.join(local_path, "lora_adapter")
        peft_config = {}
        if dist.get_rank() == 0:
            os.makedirs(lora_save_path, exist_ok=True)
            # Turn the adapter config into JSON-serialisable builtin types.
            peft_config = asdict(peft_model.peft_config.get("default", {}))
            peft_config["task_type"] = peft_config["task_type"].value
            peft_config["peft_type"] = peft_config["peft_type"].value
            peft_config["target_modules"] = list(peft_config["target_modules"])

        adapter_saved = False
        try:
            if fsdp_version(self.actor_module_fsdp) > 0:
                self.actor_module_fsdp = self.actor_module_fsdp.to(get_device_name())
                # Called on every rank; only rank 0 writes the result to disk.
                lora_params = layered_summon_lora_params(self.actor_module_fsdp)
                if dist.get_rank() == 0:
                    save_file(lora_params, os.path.join(lora_save_path, "adapter_model.safetensors"))
                    with open(os.path.join(lora_save_path, "adapter_config.json"), "w", encoding="utf-8") as f:
                        json.dump(peft_config, f, ensure_ascii=False, indent=4)
                adapter_saved = True
        except Exception as e:
            log_with_rank(
                f"Save LoRA Adapter Error ({e})", rank=dist.get_rank(), logger=logger, log_only_rank_0=True
            )

        dist.barrier()
        if adapter_saved:
            log_with_rank(
                f"[rank-{self.rank}]: Saved LoRA adapter to: {lora_save_path}",
                rank=dist.get_rank(),
                logger=logger,
                log_only_rank_0=True,
            )

    if self._is_offload_param:
        offload_fsdp_model_to_cpu(self.actor_module_fsdp)
|
|
| @register(dispatch_mode=Dispatch.ONE_TO_ALL) |
def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False):
    """Load a checkpoint into the actor module (actor or standalone rollout only).

    When ``local_path`` is ``None`` nothing is loaded. In either case the
    parameters and optimizer state are offloaded afterwards if the matching
    offload flags are set.
    """
    assert self._is_actor or self._is_rollout, (
        f"Checkpoint loading is only supported for Actor or standalone Rollout Workers, but got "
        f"{self._is_actor} and {self._is_rollout}"
    )

    if local_path is not None:
        # Parameters have to be on the device while the checkpoint is loaded.
        if self._is_offload_param:
            load_fsdp_model_to_gpu(self.actor_module_fsdp)

        self.checkpoint_manager.load_checkpoint(
            local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load
        )

    # Offloading happens whether or not a checkpoint was actually loaded.
    if self._is_offload_param:
        offload_fsdp_model_to_cpu(self.actor_module_fsdp)

    if self._is_offload_optimizer:
        offload_fsdp_optimizer(self.actor_optimizer)
|
|
| @register(dispatch_mode=Dispatch.ONE_TO_ALL) |
def start_profile(self, **kwargs) -> None:
    """Begin profiling on this rank; all keyword arguments go to the profiler."""
    profiler = self.profiler
    profiler.start(**kwargs)
|
|
| @register(dispatch_mode=Dispatch.ONE_TO_ALL) |
def stop_profile(self) -> None:
    """End profiling on this rank for the current training step."""
    profiler = self.profiler
    profiler.stop()
|
|
| @register(dispatch_mode=Dispatch.ONE_TO_ALL) |
def dump_memory_snapshot(self, tag: str = "manual", sub_dir: str = None) -> None:
    """Manually trigger a memory snapshot dump on all ranks (best-effort).

    Args:
        tag: Label forwarded to the sampler's ``dump_memory_snapshot``.
        sub_dir: Optional sub-directory forwarded to the sampler.

    Nothing happens when the worker has no profiler, the profiler has no
    ``_impl``, or the implementation has no ``sampler``. A failure during
    the dump is logged as a warning instead of being silently discarded;
    it is never raised to the caller.
    """
    if not (hasattr(self, "profiler") and hasattr(self.profiler, "_impl")):
        return

    try:
        impl = self.profiler._impl
        if not hasattr(impl, "sampler"):
            return
        # Snapshots go to the configured profiler save path, or the current directory.
        out_dir = OmegaConf.select(self.config, "actor.profiler.save_path") or "."
        impl.sampler.dump_memory_snapshot(out_dir=out_dir, tag=tag, sub_dir=sub_dir)
    except Exception as exc:
        # Still best-effort, but leave a trace so a missing snapshot can be diagnosed.
        logger.warning("Memory snapshot dump failed (tag=%s, sub_dir=%s): %s", tag, sub_dir, exc)
|
|
|
|
| class CriticWorker(Worker, DistProfilerExtension): |
| def __init__(self, config: FSDPCriticConfig): |
| Worker.__init__(self) |
| omega_profiler_config = config.get("profiler", {}) |
| profiler_config = omega_conf_to_dataclass(omega_profiler_config, dataclass_type=ProfilerConfig) |
| if omega_profiler_config.get("tool", None) in ["npu", "nsys", "torch", "torch_memory"]: |
| tool_config = omega_conf_to_dataclass( |
| omega_profiler_config.get("tool_config", {}).get(omega_profiler_config.get("tool")) |
| ) |
| else: |
| tool_config = None |
| DistProfilerExtension.__init__( |
| self, DistProfiler(rank=self.rank, config=profiler_config, tool_config=tool_config) |
| ) |
| import torch.distributed |
|
|
| self.config = config |
| if not torch.distributed.is_initialized(): |
| torch.distributed.init_process_group( |
| backend=get_nccl_backend(), |
| timeout=datetime.timedelta(seconds=self.config.get("nccl_timeout", 600)), |
| init_method=os.environ.get("DIST_INIT_METHOD", None), |
| ) |
| self.config: FSDPCriticConfig = config |
|
|
| |
| world_size = torch.distributed.get_world_size() |
| from torch.distributed.device_mesh import init_device_mesh |
|
|
| fsdp_size = self.config.model.fsdp_config.fsdp_size |
| self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size) |
|
|
| self.ulysses_device_mesh = None |
| self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) |
| dp = world_size // self.ulysses_sequence_parallel_size |
| if self.ulysses_sequence_parallel_size > 1: |
| self.ulysses_device_mesh = init_device_mesh( |
| device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"] |
| ) |
|
|
| |
| if self.ulysses_device_mesh is not None: |
| is_collect = self.ulysses_device_mesh["sp"].get_local_rank() == 0 |
| self._register_dispatch_collect_info( |
| "critic", dp_rank=self.ulysses_device_mesh["dp"].get_local_rank(), is_collect=is_collect |
| ) |
| else: |
| self._register_dispatch_collect_info("critic", dp_rank=self.rank, is_collect=True) |
|
|
| self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) |
|
|
| |
| self._is_offload_param = self.config.model.fsdp_config.param_offload |
| self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload |
|
|
| |
| self.config.ppo_mini_batch_size *= self.config.rollout_n |
| self.config.ppo_mini_batch_size //= torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size |
| if self.config.ppo_micro_batch_size is not None: |
| self.config.ppo_micro_batch_size //= ( |
| torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size |
| ) |
| self.config.forward_micro_batch_size //= ( |
| torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size |
| ) |
| self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size |
| self.config.forward_micro_batch_size_per_gpu = self.config.forward_micro_batch_size |
|
|
| if self.config.ppo_micro_batch_size_per_gpu is not None: |
| assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0, ( |
| f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be divisible by " |
| f"ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}" |
| ) |
| assert self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu > 0, ( |
| f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be larger than " |
| f"ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}" |
| ) |
| self._is_lora = ( |
| self.config.model.get("lora_adapter_path") is not None or self.config.model.get("lora_rank", 0) > 0 |
| ) |
| self.use_orig_params = self.config.model.fsdp_config.get("use_orig_params", False) |
|
|
| def _build_critic_model_optimizer(self, config: FSDPCriticConfig): |
| |
| from torch.distributed.fsdp import MixedPrecision |
|
|
| from verl.utils.model import load_valuehead_model, print_model_size |
| from verl.utils.torch_dtypes import PrecisionType |
|
|
| use_shm = config.model.get("use_shm", False) |
| local_path = copy_to_local(config.model.path, use_shm=use_shm) |
| |
| |
|
|
| tokenizer_path = copy_to_local(config.model.tokenizer_path, use_shm=use_shm) |
| self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False)) |
| self.processor = hf_processor(tokenizer_path, trust_remote_code=config.model.get("trust_remote_code", False)) |
|
|
| if self.config.model.get("custom_chat_template", None) is not None: |
| if self.processor is not None: |
| self.processor.chat_template = self.config.model.custom_chat_template |
| else: |
| self.tokenizer.chat_template = self.config.model.custom_chat_template |
| override_config = OmegaConf.to_container(OmegaConf.create(self.config.model.get("override_config", {}))) |
| override_config_kwargs = { |
| "bos_token_id": self.tokenizer.bos_token_id, |
| "eos_token_id": self.tokenizer.eos_token_id, |
| "pad_token_id": self.tokenizer.pad_token_id, |
| } |
| override_config_kwargs.update(override_config) |
| if self.rank == 0: |
| print(f"Critic overriding config {override_config_kwargs}") |
|
|
| torch_dtype = self.config.model.fsdp_config.get("model_dtype", "fp32") |
| torch_dtype = PrecisionType.to_dtype(torch_dtype) |
|
|
| from transformers import AutoConfig |
|
|
| |
| attn_implementation = override_config.get("attn_implementation", "flash_attention_2") |
| critic_model_config = AutoConfig.from_pretrained( |
| local_path, |
| attn_implementation=attn_implementation, |
| trust_remote_code=config.model.get("trust_remote_code", False), |
| ) |
| |
| |
| |
| if self.ulysses_sequence_parallel_size > 1 and hasattr(critic_model_config, "vision_config"): |
| critic_model_config.vision_config._attn_implementation = "eager" |
|
|
| critic_model_config.num_labels = 1 |
| |
| if getattr(critic_model_config, "model_type", None) == "kimi_vl": |
| critic_model_config.text_config.topk_method = "greedy" |
|
|
| init_context = get_init_weight_context_manager( |
| use_meta_tensor=not critic_model_config.tie_word_embeddings, mesh=self.device_mesh |
| ) |
|
|
| |
| tiled_mlp_config = config.model.get("tiled_mlp", {}) |
| use_tiled_mlp = tiled_mlp_config.get("enabled", False) |
| tiled_mlp_shards = tiled_mlp_config.get("num_shards", 4) |
|
|
| |
| if use_tiled_mlp and config.strategy == "fsdp": |
| raise ValueError("TiledMLP requires FSDP2. Set `critic.strategy=fsdp2`.") |
|
|
| with init_context(), warnings.catch_warnings(): |
| warnings.simplefilter("ignore") |
| critic_model_config.classifier_dropout = 0.0 |
| critic_model_config.hidden_dropout = "0" |
| critic_model_config.summary_dropout_prob = 0.0 |
|
|
| critic_module = load_valuehead_model( |
| local_path, |
| torch_dtype, |
| critic_model_config, |
| config.model.get("trust_remote_code", False), |
| ) |
|
|
| use_remove_padding = config.model.get("use_remove_padding", False) |
|
|
| apply_monkey_patch( |
| model=critic_module, |
| use_remove_padding=use_remove_padding, |
| ulysses_sp_size=self.ulysses_sequence_parallel_size, |
| use_tiled_mlp=use_tiled_mlp, |
| tiled_mlp_shards=tiled_mlp_shards, |
| ) |
|
|
| |
| critic_module.to(torch_dtype) |
|
|
| if config.model.get("enable_gradient_checkpointing", False): |
| critic_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) |
|
|
| if self._is_lora: |
| print("Applying LoRA to critic module") |
| critic_module.enable_input_require_grads() |
|
|
| |
| lora_adapter_path = self.config.model.get("lora_adapter_path") |
| if lora_adapter_path is not None: |
| from peft import PeftModel |
|
|
| print(f"Loading pre-trained LoRA adapter to critic from: {lora_adapter_path}") |
|
|
| |
| local_adapter_path = copy_to_local(lora_adapter_path, use_shm=self.config.model.get("use_shm", False)) |
|
|
| critic_module = PeftModel.from_pretrained(critic_module, local_adapter_path, is_trainable=True) |
| peft_config = critic_module.peft_config["default"] |
| |
| |
| if isinstance(peft_config.task_type, str): |
| peft_config.task_type = TaskType.TOKEN_CLS |
|
|
| else: |
| |
| |
| lora_config = { |
| "task_type": TaskType.TOKEN_CLS, |
| "r": self.config.model.lora_rank, |
| "lora_alpha": self.config.model.lora_alpha, |
| "target_modules": convert_to_regular_types(self.config.model.target_modules), |
| "bias": "none", |
| } |
| critic_module = get_peft_model(critic_module, LoraConfig(**lora_config)) |
|
|
| if self.rank == 0: |
| print_model_size(critic_module) |
|
|
| self.critic_model_config = critic_model_config |
|
|
| fsdp_config = self.config.model.fsdp_config |
| mixed_precision_config = fsdp_config.get("mixed_precision", None) |
| if mixed_precision_config is not None: |
| param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16")) |
| reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get("reduce_dtype", "fp32")) |
| buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get("buffer_dtype", "fp32")) |
| else: |
| param_dtype = torch.bfloat16 |
| reduce_dtype = torch.float32 |
| buffer_dtype = torch.float32 |
|
|
| mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype) |
|
|
| auto_wrap_policy = get_fsdp_wrap_policy( |
| module=critic_module, |
| config=self.config.model.fsdp_config.wrap_policy, |
| is_lora=self._is_lora, |
| ) |
|
|
| log_gpu_memory_usage("Before critic FSDP", logger=None) |
|
|
| fsdp_mesh = self.device_mesh |
| sharding_strategy = get_sharding_strategy(fsdp_mesh) |
|
|
| self.use_orig_params = fsdp_config.get("use_orig_params", False) |
| if self.config.model.get("freeze_vision_tower", False): |
| vision_tower = get_vl_model_vision_tower(critic_module) |
| if vision_tower is not None: |
| vision_tower.requires_grad_(False) |
| self.use_orig_params = True |
| if self.rank == 0: |
| print("[critic model] Vision tower is set to not trainable.") |
| else: |
| if self.rank == 0: |
| print("[critic model] No vision tower found.") |
|
|
| |
| if config.strategy == "fsdp": |
| critic_module = FSDP( |
| critic_module, |
| param_init_fn=init_fn, |
| use_orig_params=self.use_orig_params, |
| auto_wrap_policy=auto_wrap_policy, |
| device_id=get_device_id(), |
| sharding_strategy=sharding_strategy, |
| mixed_precision=mixed_precision, |
| sync_module_states=True, |
| forward_prefetch=self.config.model.fsdp_config.forward_prefetch, |
| device_mesh=self.device_mesh, |
| cpu_offload=None, |
| ) |
| elif config.strategy == "fsdp2": |
| assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" |
| mp_policy = MixedPrecisionPolicy( |
| param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True |
| ) |
| offload_policy = None |
| if fsdp_config.offload_policy: |
| self._is_offload_param = False |
| self._is_offload_optimizer = False |
| offload_policy = CPUOffloadPolicy(pin_memory=True) |
|
|
| fsdp_kwargs = { |
| "mesh": fsdp_mesh, |
| "mp_policy": mp_policy, |
| "offload_policy": offload_policy, |
| "reshard_after_forward": fsdp_config.reshard_after_forward, |
| "shard_placement_fn": get_shard_placement_fn(fsdp_size=self.device_mesh.shape[-1]), |
| } |
| full_state = critic_module.state_dict() |
| apply_fsdp2(critic_module, fsdp_kwargs, fsdp_config) |
| fsdp2_load_full_state_dict(critic_module, full_state, fsdp_mesh, offload_policy) |
| else: |
| raise NotImplementedError(f"Unknown strategy {config.strategy}") |
|
|
| if config.model.get("enable_activation_offload", False): |
| enable_gradient_checkpointing = config.model.get("enable_gradient_checkpointing", False) |
| enable_activation_offloading(critic_module, config.strategy, enable_gradient_checkpointing) |
|
|
| log_gpu_memory_usage("After critic FSDP", logger=None) |
|
|
| critic_optimizer = build_optimizer(critic_module.parameters(), config.optim) |
|
|
| total_steps = config.optim.get("total_training_steps", 0) |
| num_warmup_steps = int(config.optim.get("lr_warmup_steps", -1)) |
|
|
| lr_scheduler_type = config.optim.get("lr_scheduler_type", "constant") |
| if num_warmup_steps < 0: |
| num_warmup_steps_ratio = config.optim.get("lr_warmup_steps_ratio", 0.0) |
| num_warmup_steps = int(num_warmup_steps_ratio * total_steps) |
|
|
| if self.rank == 0: |
| print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}") |
|
|
| from verl.utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup |
|
|
| if lr_scheduler_type == "constant": |
| critic_lr_scheduler = get_constant_schedule_with_warmup( |
| optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps |
| ) |
| elif lr_scheduler_type == "cosine": |
| min_lr_ratio = config.optim.get("min_lr_ratio", 0.0) |
| num_cycles = config.optim.get("num_cycles", 0.5) |
| critic_lr_scheduler = get_cosine_schedule_with_warmup( |
| optimizer=critic_optimizer, |
| num_warmup_steps=num_warmup_steps, |
| num_training_steps=total_steps, |
| min_lr_ratio=min_lr_ratio, |
| num_cycles=num_cycles, |
| ) |
| else: |
| raise NotImplementedError(f"LR scheduler type {lr_scheduler_type} is not supported") |
|
|
| return critic_module, critic_optimizer, critic_lr_scheduler |
|
|
| @register(dispatch_mode=Dispatch.ONE_TO_ALL) |
| def init_model(self): |
| |
| import_external_libs(self.config.model.get("external_lib", None)) |
|
|
| from verl.workers.critic import DataParallelPPOCritic |
|
|
| self.critic_module, self.critic_optimizer, self.critic_lr_scheduler = self._build_critic_model_optimizer( |
| self.config |
| ) |
|
|
| if self._is_offload_param: |
| offload_fsdp_model_to_cpu(self.critic_module) |
| log_gpu_memory_usage("After offload critic model during init", logger=logger) |
| if self._is_offload_optimizer: |
| offload_fsdp_optimizer(optimizer=self.critic_optimizer) |
| log_gpu_memory_usage("After offload critic optimizer during init", logger=logger) |
|
|
| self.critic = DataParallelPPOCritic( |
| config=self.config, critic_module=self.critic_module, critic_optimizer=self.critic_optimizer |
| ) |
|
|
| self.flops_counter = FlopsCounter(self.critic_model_config) |
| self.checkpoint_manager = FSDPCheckpointManager( |
| model=self.critic_module, |
| optimizer=self.critic_optimizer, |
| lr_scheduler=self.critic_lr_scheduler, |
| processing_class=self.processor if self.processor is not None else self.tokenizer, |
| checkpoint_config=self.config.checkpoint, |
| trust_remote_code=self.config.model.get("trust_remote_code", False), |
| ) |
|
|
| @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="critic")) |
| @DistProfiler.annotate(color="cyan", role="compute_values") |
| def compute_values(self, data: DataProto): |
| if self._is_offload_param: |
| load_fsdp_model_to_gpu(self.critic_module) |
| micro_batch_size = self.config.forward_micro_batch_size_per_gpu |
| data.meta_info["micro_batch_size"] = micro_batch_size |
| data.meta_info["max_token_len"] = self.config.forward_max_token_len_per_gpu |
| data.meta_info["use_dynamic_bsz"] = self.config.use_dynamic_bsz |
| |
| with self.ulysses_sharding_manager: |
| data = data.to("cpu") |
| values = self.critic.compute_values(data=data) |
| output = DataProto.from_dict(tensors={"values": values}) |
|
|
| output = output.to("cpu") |
| if self._is_offload_param: |
| offload_fsdp_model_to_cpu(self.critic_module) |
| return output |
|
|
@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="critic"))
@DistProfiler.annotate(color="pink", role="critic_update")
def update_critic(self, data: DataProto):
    """Run one critic update on ``data`` and return the resulting metrics.

    On top of whatever ``self.critic.update_critic`` reports, two entries are
    added to the metrics: ``"perf/mfu/critic"`` (derived from the timed update
    and ``self.flops_counter``) and ``"critic/lr"`` (the first value of the
    scheduler's last learning rates, read *before* the scheduler is stepped).

    Args:
        data: Training batch; ``data.meta_info["global_token_num"]`` must be
            present because it feeds the FLOPs estimate.

    Returns:
        A ``DataProto`` with ``batch=None`` and ``meta_info={"metrics": ...}``,
        moved to CPU.
    """
    offload_params = self._is_offload_param
    offload_optim = self._is_offload_optimizer
    if offload_params:
        load_fsdp_model_to_gpu(self.critic_module)
    if offload_optim:
        load_fsdp_optimizer(optimizer=self.critic_optimizer, device_id=get_device_id())

    # NOTE(review): the nesting below (everything up to building ``output``
    # inside the sharding-manager context) is reconstructed from statement
    # order -- confirm against the canonical file.
    with self.ulysses_sharding_manager:
        data = data.to("cpu")
        with Timer(name="update_critic", logger=None) as timer:
            metrics = self.critic.update_critic(data=data)
        elapsed = timer.last

        token_counts = data.meta_info["global_token_num"]
        flops_estimated, flops_promised = self.flops_counter.estimate_flops(token_counts, elapsed)
        metrics["perf/mfu/critic"] = (
            flops_estimated * self.config.ppo_epochs / flops_promised / self.world_size
        )

        # Record the learning rate used for this update, then advance the schedule.
        scheduler = self.critic_lr_scheduler
        metrics["critic/lr"] = scheduler.get_last_lr()[0]
        scheduler.step()

        output = DataProto(batch=None, meta_info={"metrics": metrics})

    if offload_params:
        offload_fsdp_model_to_cpu(self.critic_module)
    if offload_optim:
        offload_fsdp_optimizer(optimizer=self.critic_optimizer)

    return output.to("cpu")
|
|
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
    """Save the critic checkpoint through ``self.checkpoint_manager``.

    All arguments are forwarded unchanged to
    ``self.checkpoint_manager.save_checkpoint``. When parameter offload is
    enabled, the critic module is loaded before saving and offloaded again
    once every rank has passed the barrier.

    Args:
        local_path: Forwarded as ``local_path``.
        hdfs_path: Forwarded as ``hdfs_path``.
        global_step: Forwarded as ``global_step``.
        max_ckpt_to_keep: Forwarded as ``max_ckpt_to_keep``.
    """
    # ``torch`` is imported at module level; no function-scope import needed.
    if self._is_offload_param:
        load_fsdp_model_to_gpu(self.critic_module)

    self.checkpoint_manager.save_checkpoint(
        local_path=local_path,
        hdfs_path=hdfs_path,
        global_step=global_step,
        max_ckpt_to_keep=max_ckpt_to_keep,
    )

    # Synchronize all ranks before parameters are moved off the device.
    torch.distributed.barrier()
    if self._is_offload_param:
        offload_fsdp_model_to_cpu(self.critic_module)
|
|
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):
    """Restore the critic checkpoint through ``self.checkpoint_manager``.

    All arguments are forwarded unchanged to
    ``self.checkpoint_manager.load_checkpoint``. When parameter offload is
    enabled, the critic module is loaded before restoring and offloaded again
    once every rank has passed the barrier.

    Args:
        local_path: Forwarded as ``local_path``.
        hdfs_path: Forwarded as ``hdfs_path``.
        del_local_after_load: Forwarded as ``del_local_after_load``.
    """
    # ``torch`` is imported at module level; no function-scope import needed.
    if self._is_offload_param:
        load_fsdp_model_to_gpu(self.critic_module)

    self.checkpoint_manager.load_checkpoint(
        local_path=local_path,
        hdfs_path=hdfs_path,
        del_local_after_load=del_local_after_load,
    )

    # Synchronize all ranks before parameters are moved off the device.
    torch.distributed.barrier()
    if self._is_offload_param:
        offload_fsdp_model_to_cpu(self.critic_module)

    # The optimizer is offloaded here although this method never loads it
    # explicitly -- presumably the restore can leave optimizer state on the
    # device; confirm against the checkpoint manager.
    if self._is_offload_optimizer:
        offload_fsdp_optimizer(optimizer=self.critic_optimizer)
|
|
|
|
| |
class AsyncActorRolloutRefWorker(ActorRolloutRefWorker):
    """``ActorRolloutRefWorker`` variant exposing an awaitable weight update."""

    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
    async def update_weights(self, global_steps: int = None):
        """Await ``self.rollout_mode()`` and report completion.

        Args:
            global_steps: Accepted for interface compatibility; not read by
                this implementation.

        Returns:
            Always ``True`` once ``rollout_mode`` has finished.
        """
        await self.rollout_mode()
        return True
|
|