# Source: TTI/Dev/verl/workers/fsdp_workers.py
# Uploaded by JosephBai with huggingface_hub ("Upload folder using huggingface_hub"), commit 857c2e9 (verified)
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
FSDP-based workers (actor, rollout and reference policy) used to run the PPO algorithm.
"""
import os
import logging
import warnings
import ray
import torch
import torch.distributed
from omegaconf import DictConfig, open_dict
from transformers import AutoModelForCausalLM
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import register, Dispatch
import verl.utils.torch_functional as verl_F
from verl import DataProto
from verl.utils.model import compute_position_id_with_mask
from verl.utils.fs import copy_local_path_from_hdfs
from verl.utils.fsdp_utils import get_fsdp_wrap_policy, load_fsdp_grad, offload_fsdp_grad, init_fn, get_init_weight_context_manager, get_fsdp_wrap_policy_vla
from verl.utils.fsdp_utils import offload_fsdp_optimizer, offload_fsdp_param_and_grad, load_fsdp_optimizer, load_fsdp_param_and_grad
from verl.utils.import_utils import import_external_libs
from verl.utils.debug import log_gpu_memory_usage
import verl.utils.hdfs_io as hdfs_io
from verl.utils import hf_tokenizer
from ..trainer.ppo import core_algos
from verl.utils.py_functional import append_to_dict
from codetiming import Timer
from verl.utils.openvla_utils import update_auto_map , check_model_logic_mismatch
from peft import LoraConfig, PeftModel, get_peft_model, TaskType
import json
# Module-level logger shared by the workers in this file; its verbosity is taken from the
# VERL_PPO_LOGGING_LEVEL environment variable and defaults to 'WARN'.
logger = logging.getLogger(__file__)
logger.setLevel(os.environ.get('VERL_PPO_LOGGING_LEVEL', 'WARN'))
def convert_to_regular_types(obj):
    """Recursively convert Hydra/OmegaConf containers and sequences to plain Python types.

    ``DictConfig`` and ``dict`` become ``dict``; ``ListConfig``, ``list`` and ``tuple`` become
    ``list``. Containers nested at any depth (including inside a ``ListConfig``) are converted
    too; every other value is returned unchanged.

    Args:
        obj: value to convert, e.g. ``config.model.target_modules`` before it is handed to peft.

    Returns:
        The converted value.
    """
    try:
        from omegaconf import DictConfig, ListConfig
    except ImportError:
        # Without omegaconf no config objects can exist; isinstance(x, ()) never matches.
        DictConfig = ListConfig = ()
    if isinstance(obj, (dict, DictConfig)):
        return {k: convert_to_regular_types(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple, ListConfig)):
        # Recurse into every element so nested DictConfig/ListConfig inside lists are unwrapped.
        return [convert_to_regular_types(x) for x in obj]
    return obj
class RobActorRolloutRefWorker(Worker):
    """
    FSDP worker for OpenVLA / OpenVLA-OFT policies (selected by ``config.model.vla``).

    Depending on ``role`` it acts as a standalone actor, a standalone HF rollout, a standalone
    reference policy, or a hybrid of these that shares one FSDP-wrapped model.
    """

    def __init__(self, config: DictConfig, role: str):
        """
        Args:
            config: worker config; the ``model``, ``actor``, ``rollout`` and ``ref`` sections are
                read as needed for ``role``.
            role: one of 'actor', 'rollout', 'ref', 'actor_rollout', 'actor_rollout_ref'.

        Initializes the NCCL process group if needed and builds a 1-D 'fsdp' device mesh over all
        ranks. The configured batch sizes are divided in place by the world size, so after
        construction they are per-rank values.
        """
        super().__init__()
        self.config = config
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend="nccl")

        # build device mesh
        world_size = torch.distributed.get_world_size()
        from torch.distributed.device_mesh import init_device_mesh
        # TODO(sgm): support FSDP hybrid shard for larger model
        self.device_mesh = init_device_mesh('cuda', mesh_shape=(world_size,), mesh_dim_names=['fsdp'])
        self._is_lora = self.config.model.get('lora_rank', 0) > 0

        self.role = role
        assert self.role in ['actor', 'rollout', 'ref', 'actor_rollout', 'actor_rollout_ref']
        self._is_actor = self.role in ['actor', 'actor_rollout', 'actor_rollout_ref']
        self._is_rollout = self.role in ['rollout', 'actor_rollout', 'actor_rollout_ref']
        self._is_ref = self.role in ['ref', 'actor_rollout_ref']

        # Manual offload flags: taken from the actor's FSDP config, or the ref's for a ref-only worker.
        self._is_offload_param = False
        self._is_offload_grad = False
        self._is_offload_optimizer = False
        if self._is_actor:
            self._is_offload_param = self.config.actor.fsdp_config.get('param_offload', False)
            self._is_offload_grad = self.config.actor.fsdp_config.get('grad_offload', False)
            self._is_offload_optimizer = self.config.actor.fsdp_config.get('optimizer_offload', False)
        elif self._is_ref:
            # TODO: it seems that manual offload is slower than FSDP offload
            self._is_offload_param = self.config.ref.fsdp_config.get('param_offload', False)

        # normalize config: global batch sizes -> per-rank batch sizes (integer division)
        dp_size = self.device_mesh.shape[0]
        if self._is_actor:
            self.config.actor.ppo_mini_batch_size //= dp_size
            self.config.actor.ppo_micro_batch_size //= dp_size
        if self._is_rollout:
            self.config.rollout.log_prob_micro_batch_size //= dp_size
        if self._is_ref:
            self.config.ref.log_prob_micro_batch_size //= dp_size

    @staticmethod
    def _sync_and_empty_cache():
        """Wait for this rank's CUDA work, barrier across all ranks, then release cached CUDA memory."""
        torch.cuda.synchronize()
        torch.distributed.barrier()
        torch.cuda.empty_cache()

    def _register_vla_classes(self, local_path):
        """Register the Prismatic/OpenVLA classes for ``config.model.vla`` with the HF Auto* factories.

        Rank 0 then runs ``update_auto_map`` and ``check_model_logic_mismatch`` on the checkpoint at
        ``local_path``; all ranks wait on a barrier afterwards (presumably so no rank loads the
        checkpoint before rank 0 is done with it).

        Raises:
            ValueError: if ``config.model.vla`` is neither 'openvla-oft' nor 'openvla'.
        """
        from transformers import AutoConfig, AutoImageProcessor, AutoModelForVision2Seq, AutoProcessor

        vla = self.config.model.vla
        if vla == "openvla-oft":
            from verl.utils.vla_utils.openvla_oft.configuration_prismatic import OpenVLAConfig
            from verl.utils.vla_utils.openvla_oft.modeling_prismatic import OpenVLAForActionPrediction
            from verl.utils.vla_utils.openvla_oft.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor
        elif vla == "openvla":
            from verl.utils.vla_utils.openvla.configuration_prismatic import OpenVLAConfig
            from verl.utils.vla_utils.openvla.modeling_prismatic import OpenVLAForActionPrediction
            from verl.utils.vla_utils.openvla.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor
        else:
            raise ValueError(f"Unsupported model.vla={vla!r}; expected 'openvla-oft' or 'openvla'")

        AutoConfig.register("openvla", OpenVLAConfig)
        AutoImageProcessor.register(OpenVLAConfig, PrismaticImageProcessor)
        AutoProcessor.register(OpenVLAConfig, PrismaticProcessor)
        AutoModelForVision2Seq.register(OpenVLAConfig, OpenVLAForActionPrediction)
        if self.rank == 0:
            update_auto_map(local_path)
            check_model_logic_mismatch(local_path)
        torch.distributed.barrier()

    def _build_model_optimizer(self,
                               model_path,
                               fsdp_config,
                               optim_config,
                               override_model_config,
                               enable_gradient_checkpointing=False,
                               trust_remote_code=False):
        """Load the VLA checkpoint, optionally apply LoRA, wrap it in FSDP and build its optimizer.

        Args:
            model_path: checkpoint path; passed through ``copy_local_path_from_hdfs`` first.
            fsdp_config: FSDP section; reads 'model_dtype', 'mixed_precision' and 'wrap_policy'.
            optim_config: optimizer section (lr, betas, weight_decay, total_training_steps,
                lr_warmup_steps_ratio), or ``None`` when no optimizer should be built.
            override_model_config: extra HF config overrides, applied after the token ids.
            enable_gradient_checkpointing: call ``gradient_checkpointing_enable()`` on the module.
            trust_remote_code: forwarded to ``hf_tokenizer`` and ``AutoConfig.from_pretrained``.

        Returns:
            ``(fsdp_module, optimizer, lr_scheduler, hf_config)``. The optimizer and scheduler are
            ``None`` unless this worker is an actor and ``optim_config`` is given.

        Side effect: sets ``self.tokenizer``.

        Raises:
            ValueError: if ``config.model.vla`` is not a supported VLA type.
        """
        from verl.utils.model import print_model_size, update_model_config
        from verl.utils.torch_dtypes import PrecisionType
        from transformers import AutoConfig, AutoModelForVision2Seq
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision
        from torch import optim

        log_gpu_memory_usage('Before init from HF AutoModel', logger=logger)
        local_path = copy_local_path_from_hdfs(model_path)
        vla = self.config.model.vla
        # Fails fast on an unsupported VLA type instead of an UnboundLocalError further down.
        self._register_vla_classes(local_path)

        # note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect
        # TODO(zhangchi.usc1992): 1. support create from random initialized model. 2. Support init with FSDP directly
        self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code, model=vla)

        torch_dtype = fsdp_config.get('model_dtype', None)
        if torch_dtype is None:
            torch_dtype = torch.float32 if self._is_actor else torch.bfloat16
        else:
            torch_dtype = PrecisionType.to_dtype(torch_dtype)

        # override model kwargs
        actor_model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)
        if self.config.model.use_remove_padding:
            from verl.models.registry import check_model_support_rmpad
            check_model_support_rmpad(actor_model_config.model_type)
        override_config_kwargs = {
            'bos_token_id': self.tokenizer.bos_token_id,
            'eos_token_id': self.tokenizer.eos_token_id,
            'pad_token_id': self.tokenizer.pad_token_id,
        }
        override_config_kwargs.update(override_model_config)
        update_model_config(actor_model_config, override_config_kwargs=override_config_kwargs)
        if self.rank == 0:
            print(f'Model config after override: {actor_model_config}')

        init_context = get_init_weight_context_manager(use_meta_tensor=not actor_model_config.tie_word_embeddings)
        with init_context(), warnings.catch_warnings():
            warnings.simplefilter("ignore")
            if vla == "openvla-oft":
                # NOTE: attn_implementation is left at its default here, unlike the 'openvla' branch.
                actor_module = AutoModelForVision2Seq.from_pretrained(
                    pretrained_model_name_or_path=local_path,
                    torch_dtype=torch_dtype,
                    config=actor_model_config,
                    trust_remote_code=True,
                )
                actor_module.vision_backbone.set_num_images_in_input(self.config.actor.num_images_in_input)
                # Attach action normalization statistics shipped with the checkpoint, if any.
                dataset_statistics_path = os.path.join(local_path, "dataset_statistics.json")
                if os.path.isfile(dataset_statistics_path):
                    with open(dataset_statistics_path, "r") as f:
                        norm_stats = json.load(f)
                    actor_module.norm_stats = norm_stats
                else:
                    print(
                        "WARNING: No local dataset_statistics.json file found for current checkpoint.\n"
                        "You can ignore this if you are loading the base VLA (i.e. not fine-tuned) checkpoint.\n"
                        "Otherwise, you may run into errors when trying to call `predict_action()` due to an absent `unnorm_key`."
                    )
            else:  # "openvla"; any other value was rejected by _register_vla_classes
                actor_module = AutoModelForVision2Seq.from_pretrained(
                    pretrained_model_name_or_path=local_path,
                    torch_dtype=torch_dtype,
                    attn_implementation="flash_attention_2",
                    config=actor_model_config,
                    trust_remote_code=True,
                )
            # some parameters may not be in torch_dtype after loading
            actor_module.to(torch_dtype)

            if enable_gradient_checkpointing:
                actor_module.gradient_checkpointing_enable()

            if self._is_lora:
                print("Applying LoRA to actor module")
                lora_config = {
                    'r': self.config.model.lora_rank,
                    'lora_alpha': self.config.model.lora_alpha,
                    'lora_dropout': 0,
                    'target_modules': convert_to_regular_types(self.config.model.target_modules),
                    'init_lora_weights': "gaussian",
                }
                actor_module = get_peft_model(actor_module, LoraConfig(**lora_config))
                actor_module.print_trainable_parameters()
        torch.distributed.barrier()

        if self.rank == 0:
            print_model_size(actor_module)
        log_gpu_memory_usage('After init from HF AutoModel', logger=logger)

        # We wrap FSDP for rollout as well
        mixed_precision_config = fsdp_config.get('mixed_precision', None)
        if mixed_precision_config is not None:
            param_dtype = PrecisionType.to_dtype(mixed_precision_config.get('param_dtype', 'bf16'))
            reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get('reduce_dtype', 'fp32'))
            buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get('buffer_dtype', 'fp32'))
        else:
            param_dtype = torch.bfloat16
            reduce_dtype = torch.float32
            buffer_dtype = torch.float32
        mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)
        if self._is_ref:
            # Workers holding a reference policy are wrapped without FSDP mixed precision.
            mixed_precision = None

        auto_wrap_policy = get_fsdp_wrap_policy_vla(module=actor_module,
                                                    config=fsdp_config.get('wrap_policy', None),
                                                    is_lora=self._is_lora)
        print(f'wrap_policy: {auto_wrap_policy}')

        # TODO(sgm): support hybrid
        # Without a wrap policy there is a single FSDP unit, so only grads/optimizer state are sharded.
        if auto_wrap_policy is None:
            sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
        else:
            sharding_strategy = ShardingStrategy.FULL_SHARD

        actor_module_fsdp = FSDP(
            actor_module,
            param_init_fn=init_fn,
            use_orig_params=False,
            auto_wrap_policy=auto_wrap_policy,
            device_id=torch.cuda.current_device(),
            sharding_strategy=sharding_strategy,  # zero3
            mixed_precision=mixed_precision,
            sync_module_states=True,
            device_mesh=self.device_mesh)
        log_gpu_memory_usage('After Actor FSDP init', logger=logger)

        # TODO: add more optimizer args into config
        # optim_config is None when building the reference model (also inside an actor_rollout_ref
        # worker, where self._is_actor is True), so both conditions are required.
        if self._is_actor and optim_config is not None:
            from verl.utils.torch_functional import get_constant_schedule_with_warmup
            actor_optimizer = optim.AdamW(actor_module_fsdp.parameters(),
                                          lr=optim_config.lr,
                                          betas=optim_config.get('betas', (0.9, 0.999)),
                                          weight_decay=optim_config.get('weight_decay', 1e-2))
            total_steps = optim_config.get('total_training_steps', 0)
            num_warmup_steps_ratio = optim_config.get('lr_warmup_steps_ratio', 0.)
            num_warmup_steps = int(num_warmup_steps_ratio * total_steps)
            print(f'Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}')
            actor_lr_scheduler = get_constant_schedule_with_warmup(optimizer=actor_optimizer,
                                                                   num_warmup_steps=num_warmup_steps)
        else:
            actor_optimizer = None
            actor_lr_scheduler = None
        log_gpu_memory_usage('After actor optimizer init', logger=logger)
        return actor_module_fsdp, actor_optimizer, actor_lr_scheduler, actor_model_config

    def _build_rollout(self):
        """Create the rollout engine and its sharding manager.

        Only the HF rollout is implemented: ``RobHFRollout`` runs directly on the FSDP-wrapped
        actor and is paired with a plain ``BaseShardingManager``.

        Raises:
            ValueError: for any other ``rollout.name`` (including 'vllm').
        """
        name = self.config.rollout.name
        if name != 'hf':
            raise ValueError(f"Unsupported rollout.name={name!r} for RobActorRolloutRefWorker; only 'hf' is implemented")
        from verl.workers.rollout import RobHFRollout
        from verl.workers.hybrid_engine import BaseShardingManager
        rollout = RobHFRollout(module=self.actor_module_fsdp, config=self.config.rollout)
        # TODO: a sharding manager that does nothing?
        sharding_manager = BaseShardingManager()
        return rollout, sharding_manager

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def init_model(self):
        """Build the actor/rollout model, the rollout engine and the reference policy required by the role."""
        from verl.workers.actor import RobDataParallelPPOActor
        # This is used to import external_lib into the huggingface systems
        import_external_libs(self.config.model.get('external_lib', None))

        from omegaconf import OmegaConf
        override_model_config = OmegaConf.to_container(self.config.model.get('override_config', OmegaConf.create()))

        if self._is_actor or self._is_rollout:
            # we need the model for actor and rollout
            if self._is_actor:
                optim_config = self.config.actor.optim
                fsdp_config = self.config.actor.fsdp_config
            else:
                optim_config = None
                fsdp_config = OmegaConf.create()
            (self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler,
             self.actor_model_config) = self._build_model_optimizer(
                model_path=self.config.model.path,
                fsdp_config=fsdp_config,
                optim_config=optim_config,
                override_model_config=override_model_config,
                enable_gradient_checkpointing=self.config.model.get('enable_gradient_checkpointing', False),
                trust_remote_code=True)

            # get the original unwrapped module
            self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module

            if self._is_offload_param:
                # param is required during state_dict in sharding manager, so only grads are offloaded here
                offload_fsdp_grad(module=self.actor_module_fsdp)
                log_gpu_memory_usage('After offload actor grad during init', logger=logger)
            if self._is_offload_optimizer:
                offload_fsdp_optimizer(optimizer=self.actor_optimizer)
                log_gpu_memory_usage('After offload actor optimizer during init', logger=logger)

        if self._is_actor:
            OmegaConf.set_struct(self.config.actor, True)
            self.actor = RobDataParallelPPOActor(config=self.config.actor,
                                                 actor_module=self.actor_module_fsdp,
                                                 actor_optimizer=self.actor_optimizer)

        if self._is_rollout:
            self.rollout, self.sharding_manager = self._build_rollout()

        if self._is_ref:
            self.ref_module_fsdp = self._build_model_optimizer(model_path=self.config.model.path,
                                                               fsdp_config=self.config.ref.fsdp_config,
                                                               optim_config=None,
                                                               override_model_config=override_model_config,
                                                               trust_remote_code=True)[0]
            if self._is_offload_param:
                offload_fsdp_param_and_grad(module=self.ref_module_fsdp, offload_grad=self._is_offload_grad)
            OmegaConf.set_struct(self.config.ref, True)
            self.ref_policy = RobDataParallelPPOActor(config=self.config.ref, actor_module=self.ref_module_fsdp)

        self._sync_and_empty_cache()

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def update_actor(self, data: DataProto):
        """Run one PPO policy update on this rank's shard of ``data`` and step the LR scheduler.

        ``data`` is handed to the actor as received (it is not moved to CUDA here).

        Returns:
            A CPU ``DataProto`` whose ``meta_info['metrics']`` holds the actor metrics, including
            the current learning rate multiplied by 1e4 under 'actor/lr(1e-4)'.
        """
        assert self._is_actor
        if self._is_offload_param:
            load_fsdp_param_and_grad(module=self.actor_module_fsdp,
                                     device_id=torch.cuda.current_device(),
                                     load_grad=self._is_offload_grad)
        if self._is_offload_optimizer:
            load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=torch.cuda.current_device())

        log_gpu_memory_usage('Before update policy', logger=logger)
        metrics = self.actor.update_policy(data=data)
        self.actor_lr_scheduler.step()
        lr = self.actor_lr_scheduler.get_last_lr()[0]
        metrics['actor/lr(1e-4)'] = lr * 1e4
        log_gpu_memory_usage('After update policy', logger=logger)

        # TODO: here, we should return all metrics
        output = DataProto(meta_info={'metrics': metrics}).to('cpu')

        if self._is_offload_param:
            offload_fsdp_param_and_grad(module=self.actor_module_fsdp, offload_grad=self._is_offload_grad)
        if self._is_offload_optimizer:
            offload_fsdp_optimizer(optimizer=self.actor_optimizer)
        self._sync_and_empty_cache()
        return output

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def compute_entropy(self, data: DataProto):
        """Compute entropy metrics for this rank's shard of ``data`` with the actor.

        Returns:
            A CPU ``DataProto`` whose ``meta_info['metrics']`` holds the actor's metrics.
        """
        data = data.to('cuda')
        assert self._is_actor
        if self._is_offload_param:
            load_fsdp_param_and_grad(module=self.actor_module_fsdp,
                                     device_id=torch.cuda.current_device(),
                                     load_grad=self._is_offload_grad)
        data.batch = data.batch.cuda()

        log_gpu_memory_usage('Before compute entropy', logger=logger)
        # NOTE(review): keyword spelled 'bacth_data' exactly as in the original call;
        # confirm it matches RobDataParallelPPOActor.compute_entropy's signature.
        metrics = self.actor.compute_entropy(bacth_data=data)
        log_gpu_memory_usage('After compute entropy', logger=logger)

        # TODO: here, we should return all metrics
        output = DataProto(meta_info={'metrics': metrics}).to('cpu')

        if self._is_offload_param:
            offload_fsdp_param_and_grad(module=self.actor_module_fsdp, offload_grad=self._is_offload_grad)
        if self._is_offload_optimizer:
            offload_fsdp_optimizer(optimizer=self.actor_optimizer)
        self._sync_and_empty_cache()
        return output

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def generate_sequences(self, prompts):
        """Roll out this rank's prompts with the HF rollout engine.

        When this worker is also the actor and ``prompts.meta_info['recompute_log_prob']``
        (default True; set False for validation) is set, the actor's log-probs of the generated
        output are stored in ``output.batch['old_log_probs']``.

        Returns:
            The rollout ``DataProto``, moved to CPU.
        """
        prompts = prompts.to('cuda')
        recompute_log_prob = prompts.meta_info.get('recompute_log_prob', True)

        assert self._is_rollout
        if self._is_offload_param:
            load_fsdp_param_and_grad(module=self.actor_module_fsdp,
                                     device_id=torch.cuda.current_device(),
                                     load_grad=self._is_offload_grad)

        prompts.batch = prompts.batch.cuda()
        prompts.meta_info.update({'eos_token_id': self.tokenizer.eos_token_id,
                                  'pad_token_id': self.tokenizer.pad_token_id})

        with self.sharding_manager:
            log_gpu_memory_usage('After entering sharding manager', logger=logger)
            prompts = self.sharding_manager.preprocess_data(prompts)
            output = self.rollout.generate_sequences(prompts=prompts)
            log_gpu_memory_usage('After rollout generation', logger=logger)
            output = self.sharding_manager.postprocess_data(output)
            torch.cuda.synchronize()

        if self._is_actor and recompute_log_prob:
            # we should always recompute old_log_probs when it is HybridEngine
            output.meta_info['micro_batch_size'] = self.config.rollout.log_prob_micro_batch_size
            output.meta_info['temperature'] = self.config.rollout.temperature
            output.meta_info['use_dynamic_bsz'] = self.config.rollout.log_prob_use_dynamic_bsz
            output.meta_info['max_token_len'] = self.config.rollout.log_prob_max_token_len_per_gpu
            output.meta_info['pad_token_id'] = self.tokenizer.pad_token_id
            output.batch['old_log_probs'] = self.actor.compute_log_prob(data=output)

        output = output.to('cpu')

        if self._is_offload_param:
            # NOTE(sgm): the grad is already in CPU, only offload param here
            offload_fsdp_param_and_grad(module=self.actor_module_fsdp, offload_grad=self._is_offload_grad)
        # clear kv cache
        self._sync_and_empty_cache()
        log_gpu_memory_usage('After recompute log prob', logger=logger)
        return output

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def compute_ref_log_prob(self, data: DataProto):
        """Compute reference-policy log-probs for this rank's shard of ``data``.

        Returns:
            A CPU ``DataProto`` with a single tensor entry ``'ref_log_prob'``.
        """
        assert self._is_ref
        data = data.to('cuda')
        if self._is_offload_param:
            load_fsdp_param_and_grad(module=self.ref_module_fsdp,
                                     device_id=torch.cuda.current_device(),
                                     load_grad=self._is_offload_grad)

        data.meta_info['micro_batch_size'] = self.config.ref.log_prob_micro_batch_size
        # The rollout temperature is used so ref log-probs are computed at the sampling temperature.
        data.meta_info['temperature'] = self.config.rollout.temperature
        data.meta_info['max_token_len'] = self.config.ref.log_prob_max_token_len_per_gpu
        data.meta_info['use_dynamic_bsz'] = self.config.ref.log_prob_use_dynamic_bsz
        data.meta_info['pad_token_id'] = self.tokenizer.pad_token_id
        ref_log_prob = self.ref_policy.compute_log_prob(data=data)
        output = DataProto.from_dict(tensors={'ref_log_prob': ref_log_prob}).to('cpu')

        if self._is_offload_param:
            offload_fsdp_param_and_grad(module=self.ref_module_fsdp, offload_grad=self._is_offload_grad)
        self._sync_and_empty_cache()
        return output

    def _save_lora_checkpoint(self, local_path):
        """Save the LoRA adapter to ``<local_path>/lora_adapter`` and a merged model to ``local_path``.

        Collective: must be called on every rank. Only rank 0 writes files and builds the merged
        model, so a single full copy of the base model is held in host memory.
        """
        import torch.distributed as dist
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        # Not imported at module level; the class was registered for OpenVLA in _register_vla_classes.
        from transformers import AutoModelForVision2Seq

        is_rank0 = dist.get_rank() == 0
        if is_rank0:
            os.makedirs(local_path, exist_ok=True)
        lora_save_path = os.path.join(local_path, "lora_adapter")

        if isinstance(self.actor_module_fsdp, FSDP):
            # Gather the full (unsharded) parameters on CPU; writeback=False since they are only read.
            with FSDP.summon_full_params(self.actor_module_fsdp, writeback=False, offload_to_cpu=True):
                if is_rank0:
                    lora_params = {}
                    model = self.actor_module_fsdp._fsdp_wrapped_module.base_model.model
                    for name, param in model.named_parameters():
                        if ".lora_" in name:
                            # Drop FSDP's wrapper segment and restore the PeftModel key prefix.
                            name = "base_model.model." + name.replace("._fsdp_wrapped_module.", ".")
                            lora_params[name] = param
                    self.actor_module_fsdp.save_pretrained(
                        lora_save_path,
                        state_dict=lora_params,
                        safe_serialization=True,
                    )
        elif is_rank0:
            self.actor_module.save_pretrained(lora_save_path, safe_serialization=True)
        dist.barrier()

        if is_rank0:
            print(f"[rank-{self.rank}]: Saved LoRA adapter to: {lora_save_path}")
            # Merge the adapter into a fresh bf16 copy of the base model and save it as a full model.
            base_vla = AutoModelForVision2Seq.from_pretrained(
                copy_local_path_from_hdfs(self.config.model.path),
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
                device_map="cpu",
            )
            merged_vla = PeftModel.from_pretrained(base_vla, lora_save_path).merge_and_unload()
            merged_vla.save_pretrained(local_path)
            self.tokenizer.save_pretrained(local_path)
            print(f"Saved merged model at: {local_path}")
        # Wait for the merged model to be saved
        dist.barrier()

    def _save_full_checkpoint(self, local_path):
        """Gather the full FSDP state dict to rank 0 (on CPU) and save model + tokenizer there."""
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType, FullStateDictConfig

        cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        with FSDP.state_dict_type(self.actor.actor_module, StateDictType.FULL_STATE_DICT, cfg):
            state_dict = self.actor.actor_module.state_dict()
        if self.rank == 0:
            print(f'Saving actor checkpoint to {local_path}')
            os.makedirs(local_path, exist_ok=True)
            self.actor_module.save_pretrained(local_path, state_dict=state_dict)
            self.tokenizer.save_pretrained(local_path)

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def save_checkpoint(self, local_path, hdfs_path=None):
        """Save the actor to ``local_path`` and, if ``hdfs_path`` is given, upload it from rank 0.

        With LoRA, the adapter goes to ``<local_path>/lora_adapter`` and a merged model to
        ``local_path``. Otherwise, a full state dict is gathered to rank 0 and saved with
        ``save_pretrained``. The tokenizer is saved alongside in both cases. Collective: must be
        called on every rank.
        """
        assert self._is_actor
        if self._is_offload_param:
            load_fsdp_param_and_grad(module=self.actor_module_fsdp,
                                     device_id=torch.cuda.current_device(),
                                     load_grad=self._is_offload_grad)

        if self._is_lora and isinstance(self.actor_module, PeftModel):
            self._save_lora_checkpoint(local_path)
        else:
            # TODO: support DCP and save sharded checkpoints
            self._save_full_checkpoint(local_path)

        if self.rank == 0 and hdfs_path is not None:
            print(f'Uploading actor checkpoint to {hdfs_path}')
            hdfs_io.makedirs(hdfs_path, exist_ok=True)
            hdfs_io.copy(src=local_path, dst=hdfs_path)
        torch.distributed.barrier()

        if self._is_offload_param:
            offload_fsdp_param_and_grad(module=self.actor_module_fsdp, offload_grad=self._is_offload_grad)
class ActorRolloutRefWorker(Worker):
"""
This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy
or a hybrid engine based on the config.rollout
"""
def __init__(self, config: DictConfig, role: str):
super().__init__()
self.config = config
import torch.distributed
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend="nccl")
# build device mesh
world_size = torch.distributed.get_world_size()
from torch.distributed.device_mesh import init_device_mesh
# TODO(sgm): support FSDP hybrid shard for larger model
self.device_mesh = init_device_mesh('cuda', mesh_shape=(world_size,), mesh_dim_names=['fsdp'])
self.role = role
assert self.role in ['actor', 'rollout', 'ref', 'actor_rollout', 'actor_rollout_ref']
self._is_actor = self.role in ['actor', 'actor_rollout', 'actor_rollout_ref']
self._is_rollout = self.role in ['rollout', 'actor_rollout', 'actor_rollout_ref']
self._is_ref = self.role in ['ref', 'actor_rollout_ref']
self._is_offload_param = False
self._is_offload_grad = False
self._is_offload_optimizer = False
if self._is_actor:
self._is_offload_param = self.config.actor.fsdp_config.get('param_offload', False)
self._is_offload_grad = self.config.actor.fsdp_config.get('grad_offload', False)
self._is_offload_optimizer = self.config.actor.fsdp_config.get('optimizer_offload', False)
elif self._is_ref:
# TODO: it seems that manual offload is slowly than FSDP offload
self._is_offload_param = self.config.ref.fsdp_config.get('param_offload', False)
# normalize config
if self._is_actor:
self.config.actor.ppo_mini_batch_size //= self.device_mesh.shape[0]
self.config.actor.ppo_micro_batch_size //= self.device_mesh.shape[0]
if self._is_rollout:
self.config.rollout.log_prob_micro_batch_size //= self.device_mesh.shape[0]
if self._is_ref:
self.config.ref.log_prob_micro_batch_size //= self.device_mesh.shape[0]
def _build_model_optimizer(self,
model_path,
fsdp_config,
optim_config,
override_model_config,
enable_gradient_checkpointing=False,
trust_remote_code=False):
from verl.utils.model import print_model_size, update_model_config
from verl.utils.torch_dtypes import PrecisionType
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision, \
CPUOffload
from torch import optim
log_gpu_memory_usage('Before init from HF AutoModel', logger=logger)
local_path = copy_local_path_from_hdfs(model_path)
# note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect
# TODO(zhangchi.usc1992): 1. support create from random initialized model. 2. Support init with FSDP directly
self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
torch_dtype = fsdp_config.get('model_dtype', None)
if torch_dtype is None:
torch_dtype = torch.float32 if self._is_actor else torch.bfloat16
else:
torch_dtype = PrecisionType.to_dtype(torch_dtype)
# override model kwargs
actor_model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)
if self.config.model.use_remove_padding:
from verl.models.registry import check_model_support_rmpad
check_model_support_rmpad(actor_model_config.model_type)
override_config_kwargs = {
'bos_token_id': self.tokenizer.bos_token_id,
'eos_token_id': self.tokenizer.eos_token_id,
'pad_token_id': self.tokenizer.pad_token_id,
}
override_config_kwargs.update(override_model_config)
update_model_config(actor_model_config, override_config_kwargs=override_config_kwargs)
if self.rank == 0:
print(f'Model config after override: {actor_model_config}')
# NOTE(fix me): tie_word_embedding causes meta_tensor init to hang
init_context = get_init_weight_context_manager(use_meta_tensor=not actor_model_config.tie_word_embeddings)
with init_context(), warnings.catch_warnings():
warnings.simplefilter("ignore")
from liger_kernel.transformers import AutoLigerKernelForCausalLM
actor_module = AutoLigerKernelForCausalLM.from_pretrained(pretrained_model_name_or_path=local_path,
torch_dtype=torch_dtype,
config=actor_model_config,
attn_implementation='flash_attention_2',
trust_remote_code=trust_remote_code)
# some parameters may not in torch_dtype. TODO(zhangchi.usc1992) remove this after we switch to fsdp2
actor_module.to(torch_dtype)
if enable_gradient_checkpointing:
actor_module.gradient_checkpointing_enable()
torch.distributed.barrier()
if self.rank == 0:
print_model_size(actor_module)
log_gpu_memory_usage('After init from HF AutoModel', logger=logger)
# We wrap FSDP for rollout as well
mixed_precision_config = fsdp_config.get('mixed_precision', None)
if mixed_precision_config is not None:
param_dtype = PrecisionType.to_dtype(mixed_precision_config.get('param_dtype', 'bf16'))
reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get('reduce_dtype', 'fp32'))
buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get('buffer_dtype', 'fp32'))
else:
param_dtype = torch.bfloat16
reduce_dtype = torch.float32
buffer_dtype = torch.float32
mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)
if self._is_ref:
mixed_precision = None
auto_wrap_policy = get_fsdp_wrap_policy(module=actor_module, config=fsdp_config.get('wrap_policy', None))
if self._is_rollout and self.config.rollout.name == 'hf':
# TODO(zhangchi.usc1992, shengguangming) fix me. Current, auto_wrap_policy causes HFRollout to hang in Gemma
auto_wrap_policy = None
print(f'wrap_policy: {auto_wrap_policy}')
# TODO(sgm): support hybrid
if auto_wrap_policy is None:
sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
else:
sharding_strategy = ShardingStrategy.FULL_SHARD
# TODO: add transformer policy
actor_module_fsdp = FSDP(
actor_module,
param_init_fn=init_fn,
use_orig_params=False,
auto_wrap_policy=auto_wrap_policy,
device_id=torch.cuda.current_device(),
sharding_strategy=sharding_strategy, # zero3
mixed_precision=mixed_precision,
sync_module_states=True,
device_mesh=self.device_mesh)
log_gpu_memory_usage('After Actor FSDP init', logger=logger)
# TODO: add more optimizer args into config
if self._is_actor:
from verl.utils.torch_functional import get_constant_schedule_with_warmup
actor_optimizer = optim.AdamW(actor_module_fsdp.parameters(),
lr=optim_config.lr,
betas=optim_config.get('betas', (0.9, 0.999)),
weight_decay=optim_config.get('weight_decay', 1e-2))
total_steps = optim_config.get('total_training_steps', 0)
num_warmup_steps_ratio = optim_config.get('lr_warmup_steps_ratio', 0.)
num_warmup_steps = int(num_warmup_steps_ratio * total_steps)
print(f'Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}')
actor_lr_scheduler = get_constant_schedule_with_warmup(optimizer=actor_optimizer,
num_warmup_steps=num_warmup_steps)
else:
actor_optimizer = None
actor_lr_scheduler = None
log_gpu_memory_usage('After actor optimizer init', logger=logger)
return actor_module_fsdp, actor_optimizer, actor_lr_scheduler, actor_model_config
def _build_rollout(self):
if self.config.rollout.name == 'hf':
from verl.workers.rollout import HFRollout
from verl.workers.hybrid_engine import BaseShardingManager
rollout = HFRollout(module=self.actor_module_fsdp, config=self.config.rollout)
sharding_manager = BaseShardingManager()
# TODO: a sharding manager that do nothing?
elif self.config.rollout.name == 'vllm':
from verl.workers.rollout.vllm_rollout import vLLMRollout
from verl.workers.hybrid_engine import FSDPVLLMShardingManager
log_gpu_memory_usage('Before building vllm rollout', logger=None)
rollout = vLLMRollout(actor_module=self.actor_module_fsdp,
config=self.config.rollout,
tokenizer=self.tokenizer,
model_hf_config=self.actor_model_config)
log_gpu_memory_usage('After building vllm rollout', logger=None)
if torch.distributed.get_world_size() == 1:
self.config.rollout.load_format = 'dummy_hf'
sharding_manager = FSDPVLLMShardingManager(module=self.actor_module_fsdp,
inference_engine=rollout.inference_engine,
model_config=self.actor_model_config,
full_params='hf' in self.config.rollout.load_format)
log_gpu_memory_usage('After building sharding manager', logger=None)
return rollout, sharding_manager
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
    """Build the actor / rollout / reference components this worker is responsible for.

    Depending on the role flags (``self._is_actor``, ``self._is_rollout``, ``self._is_ref``):

    * actor or rollout: build the FSDP policy model via ``self._build_model_optimizer``
      (with optimizer and LR scheduler only for the actor role), then optionally offload
      grads / optimizer state;
    * actor: wrap the FSDP policy in ``DataParallelPPOActor``;
    * rollout: build the rollout engine and sharding manager (``self._build_rollout``);
    * ref: build a second model from the same path (no optimizer) and wrap it in
      ``DataParallelPPOActor`` as ``self.ref_policy``.

    Ends with a CUDA sync and a collective barrier, so every rank must call this.
    """
    from verl.workers.actor import DataParallelPPOActor
    # This is used to import external_lib into the huggingface systems
    import_external_libs(self.config.model.get('external_lib', None))

    from omegaconf import OmegaConf
    # Plain-container copy of the user's HF config overrides (empty when not configured).
    override_model_config = OmegaConf.to_container(self.config.model.get('override_config', OmegaConf.create()))

    if self._is_actor or self._is_rollout:
        # we need the model for actor and rollout
        if self._is_actor:
            optim_config = self.config.actor.optim
            fsdp_config = self.config.actor.fsdp_config
        else:
            # Rollout-only worker: no optimizer, default (empty) FSDP config.
            optim_config = None
            fsdp_config = OmegaConf.create()
        self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler, self.actor_model_config = self._build_model_optimizer(
            model_path=self.config.model.path,
            fsdp_config=fsdp_config,
            optim_config=optim_config,
            override_model_config=override_model_config,
            enable_gradient_checkpointing=self.config.model.get('enable_gradient_checkpointing', False),
            trust_remote_code=self.config.model.get('trust_remote_code', False))

        # get the original unwrapped module
        self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module

        if self._is_offload_param:
            # Only grads are offloaded here: params are still required during state_dict
            # in the sharding manager.
            offload_fsdp_grad(module=self.actor_module_fsdp)
            log_gpu_memory_usage('After offload actor grad during init', logger=logger)
        if self._is_offload_optimizer:
            # NOTE(review): self.actor_optimizer is None for a rollout-only worker (optim_config=None
            # above); presumably _is_offload_optimizer is only enabled for actors — confirm.
            offload_fsdp_optimizer(optimizer=self.actor_optimizer)
            log_gpu_memory_usage('After offload actor optimizer during init', logger=logger)
    # Wrap the FSDP policy for PPO training (no checkpoint is actually loaded here).
    if self._is_actor:
        # Lock the actor config (OmegaConf struct mode) before handing it to the actor.
        OmegaConf.set_struct(self.config.actor, True)
        self.actor = DataParallelPPOActor(config=self.config.actor,
                                          actor_module=self.actor_module_fsdp,
                                          actor_optimizer=self.actor_optimizer)

    if self._is_rollout:
        self.rollout, self.sharding_manager = self._build_rollout()

    if self._is_ref:
        # Reference policy: same weights path, no optimizer; only the FSDP module ([0]) is kept.
        self.ref_module_fsdp = self._build_model_optimizer(model_path=self.config.model.path,
                                                           fsdp_config=self.config.ref.fsdp_config,
                                                           optim_config=None,
                                                           override_model_config=override_model_config,
                                                           trust_remote_code=self.config.model.get(
                                                               'trust_remote_code', False))[0]
        if self._is_offload_param:
            offload_fsdp_param_and_grad(module=self.ref_module_fsdp, offload_grad=self._is_offload_grad)

        OmegaConf.set_struct(self.config.ref, True)
        self.ref_policy = DataParallelPPOActor(config=self.config.ref, actor_module=self.ref_module_fsdp)

    # Collective: all ranks must reach this barrier.
    torch.cuda.synchronize()
    torch.distributed.barrier()
    torch.cuda.empty_cache()
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def update_actor(self, data: DataProto):
    """Run one PPO policy update on ``data`` and step the actor LR scheduler.

    If offloading is enabled, actor params/grads and optimizer state are loaded onto
    the current GPU before the update and offloaded again afterwards. Ends with a CUDA
    sync and a collective barrier, so every rank must call this.

    Args:
        data: training batch (moved to CUDA here).

    Returns:
        DataProto: on CPU, whose ``meta_info['metrics']`` holds the metrics returned by
        ``self.actor.update_policy`` plus ``'actor/lr(1e-4)'`` (current LR scaled by 1e4).
    """
    data = data.to('cuda')

    assert self._is_actor
    if self._is_offload_param:
        load_fsdp_param_and_grad(module=self.actor_module_fsdp,
                                 device_id=torch.cuda.current_device(),
                                 load_grad=self._is_offload_grad)
    if self._is_offload_optimizer:
        load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=torch.cuda.current_device())

    # NOTE(review): appears redundant after data.to('cuda') above; kept as-is.
    data.batch = data.batch.cuda()

    log_gpu_memory_usage('Before update policy', logger=logger)

    metrics = self.actor.update_policy(data=data)

    self.actor_lr_scheduler.step()
    lr = self.actor_lr_scheduler.get_last_lr()[0]
    # Reported in units of 1e-4 (key name says so) to keep the logged number readable.
    metrics['actor/lr(1e-4)'] = lr * 1e4

    log_gpu_memory_usage('After update policy', logger=logger)

    # TODO: here, we should return all metrics
    output = DataProto(meta_info={'metrics': metrics})
    output = output.to('cpu')

    if self._is_offload_param:
        offload_fsdp_param_and_grad(module=self.actor_module_fsdp, offload_grad=self._is_offload_grad)
    if self._is_offload_optimizer:
        offload_fsdp_optimizer(optimizer=self.actor_optimizer)
    # Collective: all ranks must reach this barrier.
    torch.cuda.synchronize()
    torch.distributed.barrier()
    torch.cuda.empty_cache()
    return output
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_entropy(self, data: DataProto):
    """Compute entropy metrics of the actor policy on ``data``.

    Loads offloaded actor params/grads (if offloading), calls
    ``self.actor.compute_entropy`` and offloads again. Ends with a CUDA sync and a
    collective barrier, so every rank must call this.

    Returns:
        DataProto: on CPU, with the returned metrics in ``meta_info['metrics']``.
    """
    data = data.to('cuda')

    assert self._is_actor
    if self._is_offload_param:
        load_fsdp_param_and_grad(module=self.actor_module_fsdp,
                                 device_id=torch.cuda.current_device(),
                                 load_grad=self._is_offload_grad)
    data.batch = data.batch.cuda()

    log_gpu_memory_usage('Before compute entropy', logger=logger)
    # NOTE(review): the keyword is spelled `bacth_data`; it must match the parameter name of
    # DataParallelPPOActor.compute_entropy, so only rename it together with that method.
    metrics = self.actor.compute_entropy(bacth_data=data)
    log_gpu_memory_usage('After compute entropy', logger=logger)

    # TODO: here, we should return all metrics
    output = DataProto(meta_info={'metrics': metrics})
    output = output.to('cpu')

    if self._is_offload_param:
        offload_fsdp_param_and_grad(module=self.actor_module_fsdp, offload_grad=self._is_offload_grad)
    if self._is_offload_optimizer:
        # NOTE(review): the optimizer is never loaded in this method, so this presumably
        # re-offloads state that is already offloaded — confirm it is needed.
        offload_fsdp_optimizer(optimizer=self.actor_optimizer)
    # Collective: all ranks must reach this barrier.
    torch.cuda.synchronize()
    torch.distributed.barrier()
    torch.cuda.empty_cache()
    return output
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def generate_sequences(self, prompts: DataProto):
    """Generate responses for ``prompts`` with the rollout engine.

    Steps:

    1. Load offloaded actor params (if offloading).
    2. Attach the tokenizer's eos/pad ids to ``prompts.meta_info``.
    3. Generate inside the sharding-manager context.
    4. If this worker is also the actor and ``meta_info['recompute_log_prob']`` is true
       (the default), recompute ``old_log_probs`` with the actor.

    Ends with a CUDA sync and a collective barrier, so every rank must call this.

    Args:
        prompts: prompt batch. ``meta_info['recompute_log_prob']`` (default True)
            controls step 4.

    Returns:
        DataProto: rollout output on CPU, with ``batch['old_log_probs']`` added when
        log-probs were recomputed.
    """
    prompts = prompts.to('cuda')
    # set to False if it is validation
    recompute_log_prob = prompts.meta_info.get('recompute_log_prob', True)

    assert self._is_rollout
    if self._is_offload_param:
        load_fsdp_param_and_grad(module=self.actor_module_fsdp,
                                 device_id=torch.cuda.current_device(),
                                 load_grad=self._is_offload_grad)

    prompts.batch = prompts.batch.cuda()
    # Expose the tokenizer's special-token ids to the rollout engine through meta_info.
    meta_info = {'eos_token_id': self.tokenizer.eos_token_id, 'pad_token_id': self.tokenizer.pad_token_id}
    prompts.meta_info.update(meta_info)
    with self.sharding_manager:
        log_gpu_memory_usage('After entering sharding manager', logger=logger)

        prompts = self.sharding_manager.preprocess_data(prompts)
        output = self.rollout.generate_sequences(prompts=prompts)

        log_gpu_memory_usage('After rollout generation', logger=logger)

        output = self.sharding_manager.postprocess_data(output)

    torch.cuda.synchronize()

    if self._is_actor and recompute_log_prob:
        # we should always recompute old_log_probs when it is HybridEngine
        # The forward-pass settings for the recomputation come from the rollout config.
        output.meta_info['micro_batch_size'] = self.config.rollout.log_prob_micro_batch_size
        output.meta_info['temperature'] = self.config.rollout.temperature
        output.meta_info['use_dynamic_bsz'] = self.config.rollout.log_prob_use_dynamic_bsz
        output.meta_info['max_token_len'] = self.config.rollout.log_prob_max_token_len_per_gpu
        old_log_probs = self.actor.compute_log_prob(data=output)
        output.batch['old_log_probs'] = old_log_probs

    output = output.to('cpu')

    if self._is_offload_param:
        # NOTE(sgm): the grad is already in CPU, only offload param here
        # NOTE(review): the call below still passes offload_grad=self._is_offload_grad.
        offload_fsdp_param_and_grad(module=self.actor_module_fsdp, offload_grad=self._is_offload_grad)
    # clear kv cache
    # Collective: all ranks must reach this barrier.
    torch.cuda.synchronize()
    torch.distributed.barrier()
    torch.cuda.empty_cache()
    log_gpu_memory_usage('After recompute log prob', logger=logger)
    return output
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_ref_log_prob(self, data: DataProto):
    """Compute reference-policy log-probs for ``data``.

    Reference params/grads are loaded onto the current GPU beforehand and offloaded
    afterwards when param offloading is enabled. The forward pass uses the ref
    config's micro-batch / dynamic-batch settings and the rollout temperature. Ends
    with a CUDA sync and a collective barrier, so every rank must call this.

    Returns:
        DataProto: on CPU, holding the single tensor ``'ref_log_prob'``.
    """
    assert self._is_ref
    data = data.to('cuda')

    if self._is_offload_param:
        load_fsdp_param_and_grad(module=self.ref_module_fsdp,
                                 device_id=torch.cuda.current_device(),
                                 load_grad=self._is_offload_grad)

    ref_cfg = self.config.ref
    # Temperature is shared with the rollout so log-probs are comparable.
    data.meta_info.update({
        'micro_batch_size': ref_cfg.log_prob_micro_batch_size,
        'temperature': self.config.rollout.temperature,
        'max_token_len': ref_cfg.log_prob_max_token_len_per_gpu,
        'use_dynamic_bsz': ref_cfg.log_prob_use_dynamic_bsz,
    })

    ref_log_prob = self.ref_policy.compute_log_prob(data=data)
    result = DataProto.from_dict(tensors={'ref_log_prob': ref_log_prob}).to('cpu')

    if self._is_offload_param:
        offload_fsdp_param_and_grad(module=self.ref_module_fsdp, offload_grad=self._is_offload_grad)

    # Collective: every rank must arrive here.
    torch.cuda.synchronize()
    torch.distributed.barrier()
    torch.cuda.empty_cache()
    return result
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def save_checkpoint(self, local_path, hdfs_path=None):
    """Gather the full actor state dict and save it, with the tokenizer, in HF format.

    All ranks take part in the FSDP full-state-dict gather and the final barrier.
    Only rank 0 writes ``local_path`` and, when ``hdfs_path`` is given, uploads it.

    Args:
        local_path: local directory for the checkpoint (created if missing).
        hdfs_path: optional remote destination the local directory is copied to.
    """
    assert self._is_actor
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType, FullStateDictConfig

    if self._is_offload_param:
        load_fsdp_param_and_grad(module=self.actor_module_fsdp,
                                 device_id=torch.cuda.current_device(),
                                 load_grad=self._is_offload_grad)

    # TODO: support DCP and save sharded checkpoints
    # offload_to_cpu + rank0_only: the full state dict is materialised on rank 0's CPU only.
    gather_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    fsdp_module = self.actor.actor_module
    with FSDP.state_dict_type(fsdp_module, StateDictType.FULL_STATE_DICT, gather_cfg):
        full_state = fsdp_module.state_dict()

    is_writer = self.rank == 0
    if is_writer:
        print(f'Saving actor checkpoint to {local_path}')
        os.makedirs(local_path, exist_ok=True)
        self.actor_module.save_pretrained(local_path, state_dict=full_state)
        self.tokenizer.save_pretrained(local_path)
        if hdfs_path is not None:
            print(f'Uploading actor checkpoint to {hdfs_path}')
            hdfs_io.makedirs(hdfs_path, exist_ok=True)
            hdfs_io.copy(src=local_path, dst=hdfs_path)

    # Keep non-writer ranks from racing ahead of rank 0's I/O.
    torch.distributed.barrier()
    if self._is_offload_param:
        offload_fsdp_param_and_grad(module=self.actor_module_fsdp, offload_grad=self._is_offload_grad)
class CriticWorker(Worker):
    """FSDP worker that hosts the PPO critic (value model).

    The critic loads a causal-LM backbone from ``config.model.path`` and replaces its
    ``lm_head`` with a ``hidden_size -> 1`` linear layer followed by ``squeeze``, so the
    model produces one scalar per position.
    """

    def __init__(self, config):
        """Initialise the NCCL process group if needed and convert batch sizes to per-rank values.

        Args:
            config: critic config. ``ppo_mini_batch_size`` and ``ppo_micro_batch_size`` are
                divided in place by the world size.
        """
        super().__init__()
        import torch.distributed
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend="nccl")
        self.config = config
        self._is_offload_param = self.config.model.fsdp_config.param_offload
        self._is_offload_grad = self.config.model.fsdp_config.grad_offload
        self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload

        # normalize config: global batch sizes -> per-rank batch sizes
        self.config.ppo_mini_batch_size //= torch.distributed.get_world_size()
        self.config.ppo_micro_batch_size //= torch.distributed.get_world_size()

    def _build_critic_model_optimizer(self, config):
        """Build the FSDP-wrapped critic, its AdamW optimizer and a constant-with-warmup LR scheduler.

        Side effect: sets ``self.tokenizer`` from ``config.model.tokenizer_path``.

        Returns:
            tuple: ``(critic_module, critic_optimizer, critic_lr_scheduler)``.
        """
        # the following line is necessary
        from verl.utils.model import LambdaLayer, print_model_size, squeeze
        from verl.utils.torch_dtypes import PrecisionType
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision
        from torch import optim

        local_path = copy_local_path_from_hdfs(config.model.path)
        # Honour trust_remote_code consistently for tokenizer, model config and weights
        # (the model side used to be hard-coded to False).
        trust_remote_code = config.model.get('trust_remote_code', False)

        # The critic's tokenizer may differ from the model checkpoint's; its special-token ids
        # are pushed into the model config below.
        # TODO: support loading critic weights from RM. Support using AutoModelForTokenClassification
        tokenizer_path = copy_local_path_from_hdfs(config.model.tokenizer_path)
        self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=trust_remote_code)

        from omegaconf import OmegaConf
        override_config = OmegaConf.to_container(self.config.model.get('override_config', OmegaConf.create()))
        override_config_kwargs = {
            'bos_token_id': self.tokenizer.bos_token_id,
            'eos_token_id': self.tokenizer.eos_token_id,
            'pad_token_id': self.tokenizer.pad_token_id,
        }
        # User-supplied overrides take precedence over the tokenizer-derived ids.
        override_config_kwargs.update(override_config)
        if self.rank == 0:
            print(f'Critic overriding config {override_config_kwargs}')

        torch_dtype = self.config.model.fsdp_config.get('model_dtype', 'fp32')
        torch_dtype = PrecisionType.to_dtype(torch_dtype)

        from transformers import AutoConfig, AutoModelForCausalLM
        from torch import nn

        critic_model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)
        # Apply the overrides announced above (they used to be printed but never applied).
        for key, val in override_config_kwargs.items():
            setattr(critic_model_config, key, val)

        init_context = get_init_weight_context_manager()
        with init_context(), warnings.catch_warnings():
            warnings.simplefilter("ignore")
            critic_module = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=local_path,
                                                                 torch_dtype=torch_dtype,
                                                                 config=critic_model_config,
                                                                 attn_implementation='flash_attention_2',
                                                                 trust_remote_code=trust_remote_code)
            # Value head: replace the LM head with a scalar projection (freshly initialised).
            critic_module.lm_head = nn.Sequential(nn.Linear(critic_model_config.hidden_size, 1, dtype=torch_dtype),
                                                  LambdaLayer(fn=squeeze))

            # some parameters may not in torch_dtype
            critic_module.to(torch_dtype)

            if config.model.get('enable_gradient_checkpointing', False):
                critic_module.gradient_checkpointing_enable()
        if self.rank == 0:
            print_model_size(critic_module)

        fsdp_config = self.config.model.fsdp_config
        mixed_precision_config = fsdp_config.get('mixed_precision', None)
        if mixed_precision_config is not None:
            param_dtype = PrecisionType.to_dtype(mixed_precision_config.get('param_dtype', 'bf16'))
            reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get('reduce_dtype', 'fp32'))
            buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get('buffer_dtype', 'fp32'))
        else:
            # Defaults: bf16 compute, fp32 gradient reduction and buffers.
            param_dtype = torch.bfloat16
            reduce_dtype = torch.float32
            buffer_dtype = torch.float32

        mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)

        auto_wrap_policy = get_fsdp_wrap_policy(module=critic_module, config=self.config.model.fsdp_config.wrap_policy)

        log_gpu_memory_usage('Before critic FSDP', logger=None)

        critic_module = FSDP(critic_module,
                             param_init_fn=init_fn,
                             use_orig_params=False,
                             auto_wrap_policy=auto_wrap_policy,
                             device_id=torch.cuda.current_device(),
                             sharding_strategy=ShardingStrategy.FULL_SHARD,
                             mixed_precision=mixed_precision,
                             sync_module_states=True)

        log_gpu_memory_usage('After critic FSDP', logger=None)

        critic_optimizer = optim.AdamW(critic_module.parameters(),
                                       lr=config.optim.lr,
                                       betas=config.optim.get('betas', (0.9, 0.999)),
                                       weight_decay=config.optim.get('weight_decay', 1e-2))

        total_steps = config.optim.get('total_training_steps', 0)
        num_warmup_steps_ratio = config.optim.get('lr_warmup_steps_ratio', 0.)
        num_warmup_steps = int(num_warmup_steps_ratio * total_steps)

        print(f'Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}')

        from verl.utils.torch_functional import get_constant_schedule_with_warmup
        critic_lr_scheduler = get_constant_schedule_with_warmup(optimizer=critic_optimizer,
                                                                num_warmup_steps=num_warmup_steps)

        return critic_module, critic_optimizer, critic_lr_scheduler

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def init_model(self):
        """Build the critic model/optimizer/scheduler, apply offloading and wrap it in DataParallelPPOCritic."""
        # This is used to import external_lib into the huggingface systems
        import_external_libs(self.config.model.get('external_lib', None))

        from verl.workers.critic import DataParallelPPOCritic
        self.critic_module, self.critic_optimizer, self.critic_lr_scheduler = self._build_critic_model_optimizer(
            self.config)

        if self._is_offload_param:
            offload_fsdp_param_and_grad(module=self.critic_module, offload_grad=self._is_offload_grad)
        if self._is_offload_optimizer:
            offload_fsdp_optimizer(optimizer=self.critic_optimizer)

        self.critic = DataParallelPPOCritic(config=self.config,
                                            critic_module=self.critic_module,
                                            critic_optimizer=self.critic_optimizer)
        torch.cuda.empty_cache()

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def compute_values(self, data: DataProto):
        """Compute critic values for ``data``.

        Returns:
            DataProto: on CPU, holding the single tensor ``'values'``.
        """
        data = data.to('cuda')

        if self._is_offload_param:
            load_fsdp_param_and_grad(module=self.critic_module,
                                     device_id=torch.cuda.current_device(),
                                     load_grad=self._is_offload_grad)
        micro_batch_size = self.config.ppo_micro_batch_size
        data.meta_info['micro_batch_size'] = micro_batch_size
        data.meta_info['max_token_len'] = self.config.forward_max_token_len_per_gpu
        data.meta_info['use_dynamic_bsz'] = self.config.use_dynamic_bsz
        values = self.critic.compute_values(data=data)
        output = DataProto.from_dict(tensors={'values': values})
        output = output.to('cpu')
        if self._is_offload_param:
            offload_fsdp_param_and_grad(module=self.critic_module, offload_grad=self._is_offload_grad)
        torch.cuda.empty_cache()
        return output

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def update_critic(self, data: DataProto):
        """Run one critic update on ``data`` and step the critic LR scheduler.

        Returns:
            DataProto: on CPU, with the update metrics plus ``'critic/lr(1e-4)'``
            (current LR scaled by 1e4) in ``meta_info['metrics']``.
        """
        data = data.to('cuda')
        if self._is_offload_param:
            load_fsdp_param_and_grad(module=self.critic_module,
                                     device_id=torch.cuda.current_device(),
                                     load_grad=self._is_offload_grad)
        if self._is_offload_optimizer:
            load_fsdp_optimizer(optimizer=self.critic_optimizer, device_id=torch.cuda.current_device())

        metrics = self.critic.update_critic(data=data)

        self.critic_lr_scheduler.step()
        lr = self.critic_lr_scheduler.get_last_lr()[0]
        metrics['critic/lr(1e-4)'] = lr * 1e4

        output = DataProto(batch=None, meta_info={'metrics': metrics})
        if self._is_offload_param:
            offload_fsdp_param_and_grad(module=self.critic_module, offload_grad=self._is_offload_grad)
        if self._is_offload_optimizer:
            offload_fsdp_optimizer(optimizer=self.critic_optimizer)
        torch.cuda.empty_cache()
        output = output.to('cpu')
        return output

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def save_checkpoint(self, local_path, hdfs_path=None):
        """Gather the full critic state dict and save it (with the tokenizer) in HF format.

        All ranks take part in the gather and the final barrier. Only rank 0 writes to
        ``local_path`` and optionally uploads it to ``hdfs_path``.
        """
        import torch
        if self._is_offload_param:
            load_fsdp_param_and_grad(module=self.critic_module,
                                     device_id=torch.cuda.current_device(),
                                     load_grad=self._is_offload_grad)

        # TODO: support DCP and save sharded checkpoints
        import torch.distributed
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType, FullStateDictConfig
        cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        with FSDP.state_dict_type(self.critic_module, StateDictType.FULL_STATE_DICT, cfg):
            state_dict = self.critic_module.state_dict()
        if self.rank == 0:
            print(f'Saving critic checkpoint to {local_path}')
            os.makedirs(local_path, exist_ok=True)
            self.critic_module._fsdp_wrapped_module.save_pretrained(local_path, state_dict=state_dict)
            self.tokenizer.save_pretrained(local_path)
            if hdfs_path is not None:
                print(f'Uploading critic checkpoint to {hdfs_path}')
                hdfs_io.makedirs(hdfs_path, exist_ok=True)
                hdfs_io.copy(src=local_path, dst=hdfs_path)

        torch.distributed.barrier()
        if self._is_offload_param:
            offload_fsdp_param_and_grad(module=self.critic_module, offload_grad=self._is_offload_grad)
class RewardModelWorker(Worker):
    """
    Note that we only implement the reward model that is subclass of AutoModelForSequenceClassification.

    Inference-only FSDP worker: it loads the reward model in bf16 and turns its per-sequence
    score into a token-level score tensor aligned with the policy's responses.
    """

    def __init__(self, config):
        """Initialise the NCCL process group if needed; ``config.micro_batch_size`` becomes per-rank."""
        super().__init__()
        import torch.distributed
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend="nccl")
        self.config = config

        self.config.micro_batch_size //= torch.distributed.get_world_size()

    def _build_model(self, config):
        """Load the reward model and tokenizers, then wrap the model with FSDP.

        Side effects: sets ``self.tokenizer``, ``self._do_switch_chat_template`` and, when
        ``config.model.input_tokenizer`` is set, ``self.input_tokenizer``.

        Returns:
            The FSDP-wrapped reward module.
        """
        # the following line is necessary
        from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, CPUOffload

        # download the checkpoint from hdfs
        local_path = copy_local_path_from_hdfs(config.model.path)
        trust_remote_code = config.model.get('trust_remote_code', False)

        # A separate input tokenizer means policy outputs must be re-rendered with the RM's template.
        if self.config.model.input_tokenizer is None:
            self._do_switch_chat_template = False
        else:
            self._do_switch_chat_template = True
            input_tokenizer_local_path = copy_local_path_from_hdfs(config.model.input_tokenizer)
            self.input_tokenizer = hf_tokenizer(input_tokenizer_local_path, trust_remote_code=trust_remote_code)
        self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)

        model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)
        # Inference only: the model is loaded directly in bf16 and no optimizer is built here.
        init_context = get_init_weight_context_manager(use_meta_tensor=not model_config.tie_word_embeddings)

        with init_context(), warnings.catch_warnings():
            warnings.simplefilter("ignore")
            reward_module = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path=local_path,
                                                                               torch_dtype=torch.bfloat16,
                                                                               attn_implementation='flash_attention_2',
                                                                               trust_remote_code=trust_remote_code)
            reward_module.to(torch.bfloat16)
        auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.model.fsdp_config)

        reward_module = FSDP(
            reward_module,
            param_init_fn=init_fn,
            use_orig_params=False,
            auto_wrap_policy=auto_wrap_policy,
            device_id=torch.cuda.current_device(),
            sharding_strategy=ShardingStrategy.FULL_SHARD,  # zero3
            sync_module_states=True,
            cpu_offload=CPUOffload(offload_params=self.config.model.fsdp_config.param_offload))

        return reward_module

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def init_model(self):
        """Import optional external libraries and build the FSDP reward model."""
        # This is used to import external_lib into the huggingface systems
        import_external_libs(self.config.model.get('external_lib', None))
        self.reward_module = self._build_model(config=self.config)
        torch.cuda.empty_cache()

    def _forward_micro_batch(self, micro_batch):
        """Score one micro-batch without gradients under bf16 autocast; returns one score per sequence."""
        with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            output = self.reward_module(input_ids=micro_batch['input_ids'],
                                        attention_mask=micro_batch['attention_mask'],
                                        position_ids=micro_batch['position_ids'])
            # presumably (batch_size, num_labels=1); squeezed to (batch_size,) below
            rm_score = output.logits
            rm_score = rm_score.squeeze(-1)
            return rm_score

    def _expand_to_token_level(self, data: DataProto, scores: torch.Tensor):
        """Place each sequence's scalar score at its last attended token; return the response slice.

        The last attended token is found as the argmax of ``position_ids * attention_mask``
        (assumes position ids increase over attended tokens). The result has shape
        ``(batch_size, response_length)``.
        """
        batch_size = data.batch.batch_size[0]
        # expand as token_level_reward
        attention_mask = data.batch['attention_mask']
        position_ids = data.batch['position_ids']
        response_length = data.batch['responses'].shape[-1]
        eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1)  # (bsz,)
        token_level_scores = torch.zeros_like(attention_mask, dtype=scores.dtype)  # (bsz, seqlen)
        token_level_scores[torch.arange(batch_size), eos_mask_idx] = scores

        # select the response part
        token_level_scores = token_level_scores[:, -response_length:]

        return token_level_scores

    def _switch_chat_template(self, data: DataProto):
        """Re-render each (prompt, response) pair with the reward model's chat template.

        For every sample, this:

        * takes ``non_tensor_batch['raw_prompt']`` (a chat message list);
        * decodes the valid response tokens with ``self.input_tokenizer`` and strips that
          tokenizer's EOS text;
        * appends the response as an ``assistant`` turn;
        * applies ``self.tokenizer``'s chat template and re-tokenizes with right padding
          (truncation side from ``config.truncation``, default ``'right'``).

        Returns:
            DataProto with ``input_ids``, ``attention_mask`` and ``position_ids`` for the RM.
        """
        src_max_length = data.batch['attention_mask'].shape[-1]

        src_tokenizer = self.input_tokenizer
        target_tokenizer = self.tokenizer

        # the maximum length is actually determined by the reward model itself (loop-invariant)
        max_length = self.config.get('max_length', src_max_length)
        if max_length is None:
            max_length = src_max_length

        rm_input_ids = []
        rm_attention_mask = []

        for i in range(data.batch.batch_size[0]):
            # extract raw prompt
            chat: list = data.non_tensor_batch['raw_prompt'][i].tolist()

            # extract response
            response_ids = data.batch['responses'][i]
            response_length = response_ids.shape[-1]
            # assumes responses are right-padded so the valid tokens come first — TODO confirm
            valid_response_length = data.batch['attention_mask'][i][-response_length:].sum()
            valid_response_ids = response_ids[:valid_response_length]

            # decode
            response = src_tokenizer.decode(valid_response_ids)
            # remove eos; tokenizers without an eos token have nothing to strip
            # (str.replace(None, '') would raise TypeError)
            if src_tokenizer.eos_token:
                response = response.replace(src_tokenizer.eos_token, '')

            chat.append({'role': 'assistant', 'content': response})

            prompt_with_chat_template = target_tokenizer.apply_chat_template(chat,
                                                                             add_generation_prompt=False,
                                                                             tokenize=False)
            if self.rank == 0 and i == 0:
                # for debugging purpose
                print(f'Switch template. chat: {prompt_with_chat_template}')

            input_ids, attention_mask = verl_F.tokenize_and_postprocess_data(
                prompt=prompt_with_chat_template,
                tokenizer=target_tokenizer,
                max_length=max_length,
                pad_token_id=target_tokenizer.pad_token_id,
                left_pad=False,  # right padding
                truncation=self.config.get('truncation', 'right'))  # truncate from the right

            rm_input_ids.append(input_ids)
            rm_attention_mask.append(attention_mask)

        rm_input_ids = torch.cat(rm_input_ids, dim=0)
        rm_attention_mask = torch.cat(rm_attention_mask, dim=0)

        rm_position_ids = compute_position_id_with_mask(rm_attention_mask)

        rm_inputs = {'input_ids': rm_input_ids, 'attention_mask': rm_attention_mask, 'position_ids': rm_position_ids}

        return DataProto.from_dict(rm_inputs)

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def compute_rm_score(self, data: DataProto):
        """Score a batch of responses with the reward model.

        With an ``input_tokenizer`` configured, the conversation is first re-rendered with the
        reward model's chat template. Otherwise the policy's own ``input_ids`` /
        ``attention_mask`` / ``position_ids`` are scored directly. That branch previously left
        ``rm_data`` unbound and raised ``UnboundLocalError``.

        Returns:
            DataProto: on CPU, with ``'rm_scores'`` of shape ``(batch_size, response_length)``.
            It is zero except for the score placed at each sequence's last attended token.
        """
        data = data.to('cuda')
        if self._do_switch_chat_template:
            rm_data = self._switch_chat_template(data)
        else:
            # Same tokenizer as the policy: reuse its token ids as-is.
            rm_data = DataProto.from_dict({
                'input_ids': data.batch['input_ids'],
                'attention_mask': data.batch['attention_mask'],
                'position_ids': data.batch['position_ids'],
            })

        rm_data.batch = rm_data.batch.cuda()
        micro_batches = rm_data.batch.split(self.config.micro_batch_size)
        output = []
        for micro_batch in micro_batches:
            rm_score = self._forward_micro_batch(micro_batch)
            output.append(rm_score)
        scores = torch.cat(output, dim=0)  # (batch_size)
        # Scores are aligned against the ORIGINAL data's masks/responses, not the RM re-tokenization.
        token_level_scores = self._expand_to_token_level(data, scores)
        # Note that this is only the scores, may not be the final rewards used to train RL
        output = DataProto.from_dict(tensors={'rm_scores': token_level_scores})
        output = output.to('cpu')
        torch.cuda.empty_cache()
        return output
class PRIMERewardModelWorker(Worker):
"""
PRIME reward model.
Can update itself whenever compute_rm_score is called.
"""
def __init__(self, config):
    """Set up the process group, per-rank batch sizes, the 1-D FSDP device mesh and offload flags.

    ``config.mini_batch_size`` and ``config.micro_batch_size`` are divided in place by the
    world size, so they are per-rank afterwards. Offload flags default to ``False`` when
    absent from ``config.prime_model.fsdp_config``.
    """
    super().__init__()
    import torch.distributed
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl")

    self.config = config
    n_ranks = torch.distributed.get_world_size()
    cfg = self.config
    cfg.mini_batch_size //= n_ranks
    cfg.micro_batch_size //= n_ranks

    # build device mesh
    # TODO(sgm): support FSDP hybrid shard for larger model
    from torch.distributed.device_mesh import init_device_mesh
    self.device_mesh = init_device_mesh('cuda', mesh_shape=(n_ranks,), mesh_dim_names=['fsdp'])

    offload_cfg = cfg.prime_model.fsdp_config
    self._is_offload_param = offload_cfg.get('param_offload', False)
    self._is_offload_grad = offload_cfg.get('grad_offload', False)
    self._is_offload_optimizer = offload_cfg.get('optimizer_offload', False)
def _build_model_optimizer(self, config, enable_gradient_checkpointing=False):
    """Build the trainable PRIME reward model, an optional frozen reference model and the optimizer.

    Side effects on ``self``:

    * ``tokenizer`` and ``_do_switch_chat_template`` (plus ``input_tokenizer`` when
      ``config.prime_model.input_tokenizer`` is set);
    * ``update_dpo_type``;
    * ``reward_optimizer`` / ``reward_lr_scheduler``. These are ``None`` unless ``update`` is
      ``'before'`` or ``'after'``; previously they were left undefined, which made
      ``init_model`` fail with ``AttributeError``.

    Args:
        config: full worker config; model settings are read from ``config.prime_model``.
        enable_gradient_checkpointing: enable gradient checkpointing on the reward model.
            It is also enabled when ``config.prime_model.enable_gradient_checkpointing`` is true.

    Returns:
        tuple: ``(reward_module, reference_module)``. ``reference_module`` is ``None`` unless
        ``config.prime_model.ref_type`` is ``'freeze'`` (the default).
    """
    # the following line is necessary
    from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision

    # download the checkpoint from hdfs
    local_path = copy_local_path_from_hdfs(config.prime_model.path)
    trust_remote_code = config.prime_model.get('trust_remote_code', False)

    if self.config.prime_model.input_tokenizer is None:
        self._do_switch_chat_template = False
    else:
        self._do_switch_chat_template = True
        input_tokenizer_local_path = copy_local_path_from_hdfs(config.prime_model.input_tokenizer)
        self.input_tokenizer = hf_tokenizer(input_tokenizer_local_path, trust_remote_code=trust_remote_code)
    self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)

    model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)
    if config.prime_model.use_remove_padding:
        from verl.models.registry import check_model_support_rmpad
        check_model_support_rmpad(model_config.model_type)

    # note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect
    init_context = get_init_weight_context_manager(use_meta_tensor=not model_config.tie_word_embeddings)
    with init_context(), warnings.catch_warnings():
        warnings.simplefilter("ignore")
        from liger_kernel.transformers import AutoLigerKernelForCausalLM
        reward_module = AutoLigerKernelForCausalLM.from_pretrained(pretrained_model_name_or_path=local_path,
                                                                   torch_dtype=torch.float32,
                                                                   attn_implementation='flash_attention_2',
                                                                   trust_remote_code=trust_remote_code)
        reward_module.to(torch.float32)
        # Either the explicit argument or the config flag turns checkpointing on (enabled once).
        if enable_gradient_checkpointing or config.prime_model.get('enable_gradient_checkpointing', False):
            reward_module.gradient_checkpointing_enable()

        # Frozen bf16 reference model (only trained parameters need fp32 master weights).
        if config.prime_model.get("ref_type", 'freeze') == 'freeze':
            reference_module = AutoLigerKernelForCausalLM.from_pretrained(
                pretrained_model_name_or_path=copy_local_path_from_hdfs(config.prime_model.ref_path),
                torch_dtype=torch.bfloat16,
                attn_implementation='flash_attention_2',
                trust_remote_code=trust_remote_code)
            reference_module.to(torch.bfloat16)
            for param in reference_module.parameters():
                param.requires_grad = False
        else:
            reference_module = None

    # bf16 compute with fp32 gradient reduction and buffers for the trainable reward model.
    mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32,
                                     buffer_dtype=torch.float32)

    auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.prime_model.fsdp_config)
    reward_module = FSDP(
        reward_module,
        param_init_fn=init_fn,
        use_orig_params=False,
        auto_wrap_policy=auto_wrap_policy,
        device_id=torch.cuda.current_device(),
        sharding_strategy=ShardingStrategy.FULL_SHARD,  # zero3
        mixed_precision=mixed_precision,
        device_mesh=self.device_mesh,
        sync_module_states=True)

    if reference_module is not None:
        ref_wrap_policy = get_fsdp_wrap_policy(module=reference_module, config=self.config.prime_model.fsdp_config)
        reference_module = FSDP(
            reference_module,
            param_init_fn=init_fn,
            use_orig_params=False,
            auto_wrap_policy=ref_wrap_policy,
            device_id=torch.cuda.current_device(),
            sharding_strategy=ShardingStrategy.FULL_SHARD,  # zero3
            device_mesh=self.device_mesh,
            sync_module_states=True)

    self.update_dpo_type = self.config.prime_model.get('update', 'none')
    # Always define these so init_model / compute_rm_score can read them safely.
    self.reward_optimizer = None
    self.reward_lr_scheduler = None
    if self.update_dpo_type in ['before', 'after']:
        from torch import optim
        from verl.utils.torch_functional import get_constant_schedule_with_warmup
        self.reward_optimizer = optim.AdamW(reward_module.parameters(),
                                            lr=config.prime_model.optim.lr,
                                            betas=config.prime_model.optim.get('betas', (0.9, 0.999)),
                                            weight_decay=config.prime_model.optim.get('weight_decay', 1e-2))

        total_steps = config.prime_model.optim.get('total_training_steps', 0)
        num_warmup_steps_ratio = config.prime_model.optim.get('lr_warmup_steps_ratio', 0.)
        num_warmup_steps = int(num_warmup_steps_ratio * total_steps)

        print(f'Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}')

        self.reward_lr_scheduler = get_constant_schedule_with_warmup(optimizer=self.reward_optimizer,
                                                                     num_warmup_steps=num_warmup_steps)

    # fsdp offload configurations
    if self._is_offload_optimizer and self.reward_optimizer is not None:
        offload_fsdp_optimizer(optimizer=self.reward_optimizer)
    if self._is_offload_param:
        offload_fsdp_param_and_grad(module=reward_module, offload_grad=self._is_offload_grad)
        if reference_module is not None:
            offload_fsdp_param_and_grad(module=reference_module, offload_grad=self._is_offload_grad)

    return reward_module, reference_module
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
    """Build the PRIME reward/reference models and wrap them in ``DataParallelPRIME``.

    ``self.reward_optimizer`` is created by ``_build_model_optimizer`` (only when reward
    updates are enabled there). The loss type comes from ``prime_model.loss_type``
    (default ``'ce'``).
    """
    from verl.workers.actor import DataParallelPRIME
    prime_cfg = self.config.prime_model

    # This is used to import external_lib into the huggingface systems
    import_external_libs(prime_cfg.get('external_lib', None))

    use_grad_ckpt = prime_cfg.get('enable_gradient_checkpointing', False)
    modules = self._build_model_optimizer(config=self.config, enable_gradient_checkpointing=use_grad_ckpt)
    self.reward_module, self.reference_module = modules

    self.prm = DataParallelPRIME(
        config=self.config,
        reward_module=self.reward_module,
        reference_module=self.reference_module,
        reward_optimizer=self.reward_optimizer,
        prime_loss_fn=prime_cfg.get('loss_type', 'ce'),
    )
    torch.cuda.empty_cache()
def _switch_chat_template(self, data: DataProto):
    """Re-render each (prompt, response) pair with the PRIME model's chat template.

    For every sample, this:

    * takes ``non_tensor_batch['raw_prompt']`` (a chat message list);
    * decodes the valid response tokens with ``self.input_tokenizer`` and strips that
      tokenizer's EOS text (skipped when it has no EOS token, where ``str.replace(None, '')``
      used to raise ``TypeError``);
    * appends the response as an ``assistant`` turn;
    * applies ``self.tokenizer``'s chat template and re-tokenizes with right padding.

    Returns:
        DataProto with ``input_ids``, ``attention_mask`` and ``position_ids``.
    """
    src_max_length = data.batch['attention_mask'].shape[-1]

    src_tokenizer = self.input_tokenizer
    target_tokenizer = self.tokenizer

    # the maximum length is actually determined by the reward model itself (loop-invariant)
    max_length = self.config.get('max_length', src_max_length)
    if max_length is None:
        max_length = src_max_length

    rm_input_ids = []
    rm_attention_mask = []

    for i in range(data.batch.batch_size[0]):
        # extract raw prompt
        chat: list = data.non_tensor_batch['raw_prompt'][i].tolist()

        # extract response
        response_ids = data.batch['responses'][i]
        response_length = response_ids.shape[-1]
        # assumes responses are right-padded so the valid tokens come first — TODO confirm
        valid_response_length = data.batch['attention_mask'][i][-response_length:].sum()
        valid_response_ids = response_ids[:valid_response_length]

        # decode
        response = src_tokenizer.decode(valid_response_ids)
        # remove eos (if the source tokenizer defines one)
        if src_tokenizer.eos_token:
            response = response.replace(src_tokenizer.eos_token, '')

        chat.append({'role': 'assistant', 'content': response})

        prompt_with_chat_template = target_tokenizer.apply_chat_template(chat,
                                                                         add_generation_prompt=False,
                                                                         tokenize=False)
        if self.rank == 0 and i == 0:
            # for debugging purpose
            print(f'Switch template. chat: {prompt_with_chat_template}')

        input_ids, attention_mask = verl_F.tokenize_and_postprocess_data(
            prompt=prompt_with_chat_template,
            tokenizer=target_tokenizer,
            max_length=max_length,
            pad_token_id=target_tokenizer.pad_token_id,
            left_pad=False,  # right padding
            truncation=self.config.get('truncation', 'right'))  # truncate from the right

        rm_input_ids.append(input_ids)
        rm_attention_mask.append(attention_mask)

    rm_input_ids = torch.cat(rm_input_ids, dim=0)
    rm_attention_mask = torch.cat(rm_attention_mask, dim=0)

    rm_position_ids = compute_position_id_with_mask(rm_attention_mask)

    rm_inputs = {'input_ids': rm_input_ids, 'attention_mask': rm_attention_mask, 'position_ids': rm_position_ids}

    return DataProto.from_dict(rm_inputs)
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_rm_score(self, data: DataProto):
    """Score ``data`` with the PRIME reward model, optionally updating it.

    ``self.prm.update_policy`` returns token-level scores and metrics. Offloaded params
    (reward and reference models) are always brought onto the GPU for the call, since
    scoring needs them even when no update happens. Optimizer state is loaded/offloaded,
    and the LR scheduler stepped, only when reward updates are enabled
    (``update_dpo_type != 'none'``) and an optimizer actually exists.

    Returns:
        DataProto: on CPU, with tensor ``'rm_scores'`` and ``meta_info['metrics']``.
    """
    if self._do_switch_chat_template:
        rm_data = self._switch_chat_template(data)
    else:
        rm_data = data

    updates_enabled = self.update_dpo_type != 'none'
    # These may be absent/None when reward updates were not configured at build time.
    reward_optimizer = getattr(self, 'reward_optimizer', None)
    reward_lr_scheduler = getattr(self, 'reward_lr_scheduler', None)
    manage_optimizer = updates_enabled and self._is_offload_optimizer and reward_optimizer is not None

    if manage_optimizer:
        load_fsdp_optimizer(optimizer=reward_optimizer, device_id=torch.cuda.current_device())
    if self._is_offload_param:
        load_fsdp_param_and_grad(module=self.reward_module,
                                 device_id=torch.cuda.current_device(),
                                 load_grad=self._is_offload_grad)
        if self.reference_module is not None:
            load_fsdp_param_and_grad(module=self.reference_module,
                                     device_id=torch.cuda.current_device(),
                                     load_grad=self._is_offload_grad)

    token_level_scores, metrics = self.prm.update_policy(rm_data)

    output = DataProto.from_dict(tensors={'rm_scores': token_level_scores}, meta_info={'metrics': metrics})

    if updates_enabled:
        if manage_optimizer:
            offload_fsdp_optimizer(optimizer=reward_optimizer)
        if reward_lr_scheduler is not None:
            reward_lr_scheduler.step()
    if self._is_offload_param:
        offload_fsdp_param_and_grad(module=self.reward_module, offload_grad=self._is_offload_grad)
        if self.reference_module is not None:
            offload_fsdp_param_and_grad(module=self.reference_module, offload_grad=self._is_offload_grad)
    output = output.to('cpu')
    torch.cuda.empty_cache()
    return output
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def save_checkpoint(self, local_path, hdfs_path=None):
    """Save the reward model as a HuggingFace checkpoint from a full FSDP state dict.

    Every rank takes part in gathering the unsharded state dict. It is moved to
    CPU and kept on rank 0 only. Rank 0 then calls ``save_pretrained`` into
    ``local_path`` and, when ``hdfs_path`` is given, copies the directory there.
    All ranks meet at a barrier before parameters are offloaded again.

    Args:
        local_path: local directory for the checkpoint (created by rank 0).
        hdfs_path: optional remote destination mirrored from ``local_path``.
    """
    import torch
    import torch.distributed
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType, FullStateDictConfig

    module = self.reward_module

    # Offloaded parameters must be back on the GPU before FSDP can gather them.
    if self._is_offload_param:
        load_fsdp_param_and_grad(module=module,
                                 device_id=torch.cuda.current_device(),
                                 load_grad=self._is_offload_grad)

    # TODO: support DCP and save sharded checkpoints
    gather_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(module, StateDictType.FULL_STATE_DICT, gather_cfg):
        full_state = module.state_dict()

    def _persist():
        # Rank-0 only: write the HF checkpoint locally, then optionally mirror it to HDFS.
        print(f'Saving reward checkpoint to {local_path}')
        os.makedirs(local_path, exist_ok=True)
        module._fsdp_wrapped_module.save_pretrained(local_path, state_dict=full_state)
        if hdfs_path is None:
            return
        print(f'Uploading reward checkpoint to {hdfs_path}')
        hdfs_io.makedirs(hdfs_path, exist_ok=True)
        hdfs_io.copy(src=local_path, dst=hdfs_path)

    if self.rank == 0:
        _persist()

    # Hold every rank here until rank 0 has finished writing.
    torch.distributed.barrier()

    if self._is_offload_param:
        offload_fsdp_param_and_grad(module=module, offload_grad=self._is_offload_grad)