File size: 68,345 Bytes
bcdf9fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 |
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The main entry point to run the PPO algorithm
"""
import logging
import os
import warnings
from typing import Union
import psutil
import torch
import torch.distributed
from codetiming import Timer
from omegaconf import DictConfig, open_dict
from torch.distributed.device_mesh import init_device_mesh
import verl.utils.torch_functional as verl_F
from verl import DataProto
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, register
from verl.utils import hf_processor, hf_tokenizer
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
from verl.utils.debug import log_gpu_memory_usage
from verl.utils.flops_counter import FlopsCounter
from verl.utils.fs import copy_to_local
from verl.utils.fsdp_utils import (
CPUOffloadPolicy,
MixedPrecisionPolicy,
apply_fsdp2,
fsdp2_load_full_state_dict,
fsdp_version,
get_fsdp_wrap_policy,
get_init_weight_context_manager,
init_fn,
load_fsdp_model_to_gpu,
load_fsdp_optimizer,
offload_fsdp_model_to_cpu,
offload_fsdp_optimizer,
)
from verl.utils.import_utils import import_external_libs
from verl.utils.model import compute_position_id_with_mask
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
# Module-level logger used by the memory-usage logging calls in this file.
# NOTE(review): the logger is keyed by __file__ (the module's path) rather than
# the more conventional __name__; confirm nothing configures it by name before changing.
logger = logging.getLogger(__file__)
# Verbosity is taken from the VERL_LOGGING_LEVEL environment variable ("WARN" when unset).
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
def create_device_mesh(world_size, fsdp_size):
if fsdp_size < 0 or fsdp_size >= world_size:
device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"])
else:
device_mesh = init_device_mesh("cuda", mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=["ddp", "fsdp"])
return device_mesh
def get_sharding_strategy(device_mesh):
    """Select the FSDP sharding strategy from the mesh dimensionality.

    Args:
        device_mesh: Object exposing an ``ndim`` attribute.

    Returns:
        ``ShardingStrategy.FULL_SHARD`` for a 1-D mesh and
        ``ShardingStrategy.HYBRID_SHARD`` for a 2-D mesh.

    Raises:
        NotImplementedError: For any other ``ndim``.
    """
    from torch.distributed.fsdp import ShardingStrategy

    strategy_by_ndim = {
        1: ShardingStrategy.FULL_SHARD,
        2: ShardingStrategy.HYBRID_SHARD,
    }
    chosen = strategy_by_ndim.get(device_mesh.ndim)
    if chosen is None:
        raise NotImplementedError(f"Get device mesh ndim={device_mesh.ndim}, but only support 1 or 2")
    return chosen
class ActorRolloutRefWorker(Worker):
"""
This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy
or a hybrid engine based on the config.rollout
"""
def __init__(self, config: DictConfig, role: str):
    """Initialise the process group, device meshes, role flags and batch sizes.

    Args:
        config: Worker configuration. The ``actor``, ``rollout`` and ``ref``
            sections are read here, and their batch-size fields are rewritten
            in place (see the normalisation steps at the end).
        role: One of ``"actor"``, ``"rollout"``, ``"ref"``,
            ``"actor_rollout"`` or ``"actor_rollout_ref"``.
    """
    super().__init__()
    self.config = config
    import torch.distributed

    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group()

    # Device mesh for FSDP.
    # TODO(sgm): support FSDP hybrid shard for larger model
    world_size = torch.distributed.get_world_size()
    self.device_mesh = create_device_mesh(
        world_size=world_size,
        fsdp_size=self.config.actor.fsdp_config.fsdp_size,
    )

    # Device mesh for Ulysses sequence parallelism; left as None when the
    # configured sequence-parallel size is 1.
    self.ulysses_device_mesh = None
    sp_size = self.config.actor.get("ulysses_sequence_parallel_size", 1)
    self.ulysses_sequence_parallel_size = sp_size
    dp = world_size // sp_size
    if sp_size > 1:
        self.ulysses_device_mesh = init_device_mesh("cuda", mesh_shape=(dp, sp_size), mesh_dim_names=["dp", "sp"])
    self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)

    # Role flags: one worker may play several roles at once.
    self.role = role
    assert self.role in ["actor", "rollout", "ref", "actor_rollout", "actor_rollout_ref"]
    self._is_actor = self.role in ("actor", "actor_rollout", "actor_rollout_ref")
    self._is_rollout = self.role in ("rollout", "actor_rollout", "actor_rollout_ref")
    self._is_ref = self.role in ("ref", "actor_rollout_ref")

    # Offload switches: the actor settings win; the ref setting is only
    # consulted when this worker is not an actor.
    self._is_offload_param = False
    self._is_offload_optimizer = False
    if self._is_actor:
        actor_fsdp_cfg = self.config.actor.fsdp_config
        self._is_offload_param = actor_fsdp_cfg.get("param_offload", False)
        self._is_offload_optimizer = actor_fsdp_cfg.get("optimizer_offload", False)
    elif self._is_ref:
        # TODO: it seems that manual offload is slower than FSDP offload
        self._is_offload_param = self.config.ref.fsdp_config.get("param_offload", False)

    # Divisor shared by every batch-size normalisation below: the size of the
    # FSDP mesh divided by the sequence-parallel size.
    shard_divisor = self.device_mesh.size() // self.ulysses_sequence_parallel_size

    # Normalise the actor batch sizes.
    if self._is_actor:
        actor_cfg = self.config.actor
        actor_cfg.ppo_mini_batch_size *= self.config.rollout.n
        actor_cfg.ppo_mini_batch_size //= shard_divisor
        assert actor_cfg.ppo_mini_batch_size > 0, f"ppo_mini_batch_size {actor_cfg.ppo_mini_batch_size} should be larger than 0 after normalization"

        # A configured ppo_micro_batch_size is scaled the same way and then
        # stored as the per-GPU value.
        if actor_cfg.ppo_micro_batch_size is not None:
            actor_cfg.ppo_micro_batch_size //= shard_divisor
            actor_cfg.ppo_micro_batch_size_per_gpu = actor_cfg.ppo_micro_batch_size

        if actor_cfg.ppo_micro_batch_size_per_gpu is not None:
            assert actor_cfg.ppo_mini_batch_size % actor_cfg.ppo_micro_batch_size_per_gpu == 0, f"normalized ppo_mini_batch_size {actor_cfg.ppo_mini_batch_size} should be divisible by ppo_micro_batch_size_per_gpu {actor_cfg.ppo_micro_batch_size_per_gpu}"
            assert actor_cfg.ppo_mini_batch_size // actor_cfg.ppo_micro_batch_size_per_gpu > 0, f"normalized ppo_mini_batch_size {actor_cfg.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {actor_cfg.ppo_micro_batch_size_per_gpu}"

    # Normalise the rollout log-prob micro batch size.
    if self._is_rollout and self.config.rollout.log_prob_micro_batch_size is not None:
        rollout_cfg = self.config.rollout
        rollout_cfg.log_prob_micro_batch_size //= shard_divisor
        rollout_cfg.log_prob_micro_batch_size_per_gpu = rollout_cfg.log_prob_micro_batch_size

    # Normalise the reference-policy log-prob micro batch size.
    if self._is_ref and self.config.ref.log_prob_micro_batch_size is not None:
        ref_cfg = self.config.ref
        ref_cfg.log_prob_micro_batch_size //= shard_divisor
        ref_cfg.log_prob_micro_batch_size_per_gpu = ref_cfg.log_prob_micro_batch_size
def _build_model_optimizer(
    self,
    model_path,
    fsdp_config,
    optim_config,
    override_model_config,
    use_remove_padding=False,
    enable_gradient_checkpointing=False,
    trust_remote_code=False,
    use_liger=False,
    role="actor",
):
    """Load a HuggingFace model, wrap it with FSDP and optionally build its optimizer.

    Besides returning the built objects, this method sets ``self.tokenizer``,
    ``self.processor`` (always ``None`` here) and ``self.generation_config``.
    On the FSDP2 path with an actor offload policy it also resets
    ``self._is_offload_param`` and ``self._is_offload_optimizer`` to ``False``.

    Args:
        model_path: Model location; passed through ``copy_to_local`` before loading.
        fsdp_config: Mapping of FSDP options. The keys read here are
            ``model_dtype``, ``mixed_precision``, ``wrap_policy``,
            ``offload_policy`` and ``reshard_after_forward``.
        optim_config: Optimizer options, or ``None`` to skip building the optimizer.
        override_model_config: Mapping merged on top of the tokenizer-derived
            ``bos``/``eos``/``pad`` token id overrides for the model config.
        use_remove_padding: Apply the verl monkey patch (also applied whenever
            Ulysses sequence parallelism is enabled).
        enable_gradient_checkpointing: Enable non-reentrant gradient checkpointing.
        trust_remote_code: Forwarded to the tokenizer, config and model loaders.
        use_liger: Apply the Liger kernel to the loaded model instance.
        role: ``"actor"`` or ``"ref"``.

    Returns:
        Tuple ``(module, optimizer, lr_scheduler, model_config)``. The optimizer
        and scheduler are ``None`` unless ``role == "actor"`` and
        ``optim_config`` is not ``None``.

    Raises:
        ValueError: ``load_param`` is enabled but ``load_param_path`` is ``None``.
        NotImplementedError: Unsupported FSDP strategy or warmup style.
    """
    from torch import optim
    from torch.distributed.fsdp import CPUOffload, MixedPrecision
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq

    from verl.utils.model import get_generation_config, print_model_size, update_model_config
    from verl.utils.torch_dtypes import PrecisionType

    assert role in ["actor", "ref"]

    log_gpu_memory_usage(f"Before init {role} from HF AutoModel", logger=logger)
    local_path = copy_to_local(model_path)

    # note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect
    # TODO(zhangchi.usc1992): 1. support create from random initialized model. 2. Support init with FSDP directly
    self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
    self.processor = None  # hf_processor(local_path, trust_remote_code=trust_remote_code)

    # The default dtype depends on whether this *worker* is an actor
    # (self._is_actor), not on the ``role`` argument of this call.
    torch_dtype = fsdp_config.get("model_dtype", None)
    if torch_dtype is None:
        torch_dtype = torch.float32 if self._is_actor else torch.bfloat16
    else:
        torch_dtype = PrecisionType.to_dtype(torch_dtype)

    # override model kwargs
    actor_model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)

    self.generation_config = get_generation_config(local_path, trust_remote_code=trust_remote_code)

    override_config_kwargs = {
        "bos_token_id": self.tokenizer.bos_token_id,
        "eos_token_id": self.tokenizer.eos_token_id,
        "pad_token_id": self.tokenizer.pad_token_id,
    }
    override_config_kwargs.update(override_model_config)
    update_model_config(actor_model_config, override_config_kwargs=override_config_kwargs)
    if self.rank == 0:
        print(f"Model config after override: {actor_model_config}")

    # NOTE(fix me): tie_word_embedding causes meta_tensor init to hang
    init_context = get_init_weight_context_manager(use_meta_tensor=not actor_model_config.tie_word_embeddings, mesh=self.device_mesh)

    with init_context(), warnings.catch_warnings():
        warnings.simplefilter("ignore")
        if type(actor_model_config) in AutoModelForVision2Seq._model_mapping.keys():
            actor_module_class = AutoModelForVision2Seq
        else:
            actor_module_class = AutoModelForCausalLM

        actor_module = actor_module_class.from_pretrained(
            pretrained_model_name_or_path=local_path,
            torch_dtype=torch_dtype,
            config=actor_model_config,
            attn_implementation="flash_attention_2",
            trust_remote_code=trust_remote_code,
        )

        if use_remove_padding or self.ulysses_sequence_parallel_size > 1:
            from verl.models.transformers.monkey_patch import apply_monkey_patch

            apply_monkey_patch(model=actor_module, ulysses_sp_size=self.ulysses_sequence_parallel_size)

        # Apply Liger kernel to the model if use_liger is set to True
        if use_liger:
            from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance

            _apply_liger_kernel_to_instance(model=actor_module)

        # some parameters may not in torch_dtype. TODO(zhangchi.usc1992) remove this after we switch to fsdp2
        actor_module.to(torch_dtype)

        if enable_gradient_checkpointing:
            actor_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
    torch.distributed.barrier()

    if self.rank == 0:
        print_model_size(actor_module)

    log_gpu_memory_usage(f"After init {role} from HF AutoModel", logger=logger)

    # We wrap FSDP for rollout as well
    mixed_precision_config = fsdp_config.get("mixed_precision", None)
    if mixed_precision_config is not None:
        param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16"))
        reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get("reduce_dtype", "fp32"))
        buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get("buffer_dtype", "fp32"))
    else:
        param_dtype = torch.bfloat16
        reduce_dtype = torch.float32
        buffer_dtype = torch.float32

    mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)

    auto_wrap_policy = get_fsdp_wrap_policy(module=actor_module, config=fsdp_config.get("wrap_policy", None))

    if self._is_rollout and self.config.rollout.name == "hf":
        # TODO(zhangchi.usc1992, shengguangming) fix me. Current, auto_wrap_policy causes HFRollout to hang in Gemma
        auto_wrap_policy = None

    print(f"wrap_policy: {auto_wrap_policy}")

    fsdp_mesh = self.device_mesh
    sharding_strategy = get_sharding_strategy(fsdp_mesh)

    # Optionally replace the freshly loaded weights with an external state dict
    # before the module is wrapped by FSDP.
    if self.config.model.get("load_param", False):
        load_param_path = self.config.model.load_param_path
        if load_param_path is None:
            raise ValueError("load_param_path should not be None when load_param is True")
        param_path = copy_to_local(load_param_path)
        # NOTE(review): torch.load unpickles the file. If this checkpoint can
        # come from an untrusted source, pass weights_only=True — confirm the
        # checkpoint holds only tensors before enabling it.
        state_dict = torch.load(param_path, map_location="cpu")
        # NOTE(review): assign=True keeps the dtype of the loaded tensors, so
        # the torch_dtype cast applied above may no longer hold — confirm.
        actor_module.load_state_dict(state_dict, strict=True, assign=True)
        print("\n" + "=" * 60)
        print(f"✅✅✅ SUCCESS: Model loaded from: {param_path} ✅✅✅")
        print("=" * 60 + "\n")

    # TODO: add transformer policy
    # We force reference policy to use CPUOffload to save memory.
    # We force turn off CPUOffload for actor because it causes incorrect results when using grad accumulation
    cpu_offload = None if role == "actor" else CPUOffload(offload_params=True)
    # The strategy is always read from the actor section, also when building the ref model.
    fsdp_strategy = self.config.actor.strategy
    if fsdp_strategy == "fsdp":
        actor_module_fsdp = FSDP(
            actor_module,
            cpu_offload=cpu_offload,
            param_init_fn=init_fn,
            use_orig_params=False,
            auto_wrap_policy=auto_wrap_policy,
            device_id=torch.cuda.current_device(),
            sharding_strategy=sharding_strategy,  # zero3
            mixed_precision=mixed_precision,
            sync_module_states=True,
            device_mesh=self.device_mesh,
            forward_prefetch=False,
        )
    elif fsdp_strategy == "fsdp2":
        assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"
        mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True)
        # Read optional keys with .get, like every other fsdp_config access in
        # this method: a rollout-only worker passes an empty config here.
        if role == "actor" and fsdp_config.get("offload_policy", False):
            cpu_offload = CPUOffloadPolicy(pin_memory=True)
            # The offload policy replaces the manual offload switches.
            self._is_offload_param = False
            self._is_offload_optimizer = False
        else:
            cpu_offload = None if role == "actor" else CPUOffloadPolicy(pin_memory=True)

        fsdp_kwargs = {
            "mesh": fsdp_mesh,
            "mp_policy": mp_policy,
            "offload_policy": cpu_offload,
            "reshard_after_forward": fsdp_config.get("reshard_after_forward", True),
        }
        # Snapshot the full state dict before sharding and load it back afterwards.
        full_state = actor_module.state_dict()
        apply_fsdp2(actor_module, fsdp_kwargs, fsdp_config)
        fsdp2_load_full_state_dict(actor_module, full_state, fsdp_mesh, cpu_offload)
        actor_module_fsdp = actor_module
    else:
        raise NotImplementedError(f"not implement {fsdp_strategy}")

    log_gpu_memory_usage(f"After {role} FSDP init", logger=logger)

    # TODO: add more optimizer args into config
    if role == "actor" and optim_config is not None:
        from verl.utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup

        actor_optimizer = optim.AdamW(
            actor_module_fsdp.parameters(),
            lr=optim_config.lr,
            betas=optim_config.get("betas", (0.9, 0.999)),
            weight_decay=optim_config.get("weight_decay", 1e-2),
        )

        total_steps = optim_config.get("total_training_steps", 0)
        num_warmup_steps = int(optim_config.get("lr_warmup_steps", -1))
        warmup_style = optim_config.get("warmup_style", "constant")
        # A negative lr_warmup_steps means "derive it from lr_warmup_steps_ratio".
        if num_warmup_steps < 0:
            num_warmup_steps_ratio = optim_config.get("lr_warmup_steps_ratio", 0.0)
            num_warmup_steps = int(num_warmup_steps_ratio * total_steps)

        print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}")

        if warmup_style == "constant":
            actor_lr_scheduler = get_constant_schedule_with_warmup(optimizer=actor_optimizer, num_warmup_steps=num_warmup_steps)
        elif warmup_style == "cosine":
            actor_lr_scheduler = get_cosine_schedule_with_warmup(optimizer=actor_optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps)
        else:
            raise NotImplementedError(f"Warmup style {warmup_style} is not supported")

        log_gpu_memory_usage(f"After {role} optimizer init", logger=logger)
    else:
        actor_optimizer = None
        actor_lr_scheduler = None

    return actor_module_fsdp, actor_optimizer, actor_lr_scheduler, actor_model_config
def _build_rollout(self, trust_remote_code=False):
    """Create the rollout engine and its sharding manager for ``config.rollout.name``.

    Supported names are ``"hf"``, ``"vllm"``, ``"sglang"`` and
    ``"sglang_async"``. For the non-HF engines, ``config.rollout.load_format``
    is overwritten with ``"dummy_hf"`` when the world size is 1.

    Args:
        trust_remote_code: Forwarded to the rollout engine constructors.

    Returns:
        Tuple ``(rollout, rollout_sharding_manager)``.

    Raises:
        NotImplementedError: Unknown rollout name or unknown ``vllm_mode``.
    """
    from torch.distributed.device_mesh import init_device_mesh

    # TODO(sgm): support FSDP hybrid shard for larger model
    infer_tp = self.config.rollout.tensor_model_parallel_size
    dp = self.world_size // infer_tp
    assert self.world_size % infer_tp == 0, f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}"
    rollout_device_mesh = init_device_mesh("cuda", mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"])
    rollout_name = self.config.rollout.name
    if rollout_name == "hf":
        from verl.workers.rollout import HFRollout
        from verl.workers.sharding_manager.base import BaseShardingManager

        rollout = HFRollout(module=self.actor_module_fsdp, config=self.config.rollout)
        rollout_sharding_manager = BaseShardingManager()
        # TODO: a sharding manager that do nothing?

    elif rollout_name == "vllm":
        from verl.workers.rollout.vllm_rollout import vllm_mode, vLLMRollout
        from verl.workers.sharding_manager.fsdp_vllm import FSDPVLLMShardingManager

        log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=logger)
        local_path = copy_to_local(self.config.model.path)
        if vllm_mode == "customized":
            rollout = vLLMRollout(
                actor_module=self.actor_module_fsdp,
                config=self.config.rollout,
                tokenizer=self.tokenizer,
                model_hf_config=self.actor_model_config,
                trust_remote_code=trust_remote_code,
            )
        elif vllm_mode == "spmd":
            from verl.workers.rollout.vllm_rollout import vLLMAsyncRollout

            vllm_rollout_cls = vLLMRollout if self.config.rollout.mode == "sync" else vLLMAsyncRollout
            rollout = vllm_rollout_cls(
                model_path=local_path,
                config=self.config.rollout,
                tokenizer=self.tokenizer,
                model_hf_config=self.actor_model_config,
                device_mesh=rollout_device_mesh,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError("vllm_mode must be 'customized' or 'spmd'")

        log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger)
        if torch.distributed.get_world_size() == 1:
            self.config.rollout.load_format = "dummy_hf"
        rollout_sharding_manager = FSDPVLLMShardingManager(
            module=self.actor_module_fsdp,
            inference_engine=rollout.inference_engine,
            model_config=self.actor_model_config,
            full_params="hf" in self.config.rollout.load_format,
            device_mesh=rollout_device_mesh,
            offload_param=self._is_offload_param,
        )
        log_gpu_memory_usage("After building sharding manager", logger=logger)

    elif rollout_name == "sglang":
        from verl.workers.rollout.sglang_rollout import SGLangRollout

        # NOTE(linjunrong): Due to recent fp8 support in SGLang. Now importing any symbol relate to
        # SGLang's model_runner would check CUDA device capability. However, due to verl's setting,
        # the main process of ray can not find any CUDA device, which would potentially lead to:
        # "RuntimeError: No CUDA GPUs are available".
        # For this reason, sharding_manager.__init__ should not import FSDPSGLangShardingManager and
        # we import it here use the abs path.
        # check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76
        from verl.workers.sharding_manager.fsdp_sglang import FSDPSGLangShardingManager

        log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=logger)
        local_path = copy_to_local(self.config.model.path)
        rollout = SGLangRollout(
            actor_module=local_path,
            config=self.config.rollout,
            tokenizer=self.tokenizer,
            model_hf_config=self.actor_model_config,
            trust_remote_code=trust_remote_code,
        )
        log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger)

        if torch.distributed.get_world_size() == 1:
            self.config.rollout.load_format = "dummy_hf"
        rollout_sharding_manager = FSDPSGLangShardingManager(
            module=self.actor_module_fsdp,
            inference_engine=rollout.inference_engine,
            model_config=self.actor_model_config,
            full_params="hf" in self.config.rollout.load_format,
            device_mesh=rollout_device_mesh,
            offload_param=self._is_offload_param,
        )
        log_gpu_memory_usage("After building sharding manager", logger=logger)

    elif rollout_name == "sglang_async":
        from verl.workers.rollout.sglang_rollout import AsyncSGLangRollout
        from verl.workers.sharding_manager.fsdp_sglang import FSDPAsyncSGLangShardingManager

        log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=logger)
        # Resolve the model path the same way the vllm and sglang branches do,
        # so the engine receives the path returned by copy_to_local.
        local_path = copy_to_local(self.config.model.path)
        rollout = AsyncSGLangRollout(
            actor_module=local_path,
            config=self.config.rollout,
            tokenizer=self.tokenizer,
            model_hf_config=self.actor_model_config,
            trust_remote_code=trust_remote_code,
        )
        log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger)

        if torch.distributed.get_world_size() == 1:
            self.config.rollout.load_format = "dummy_hf"
        # NOTE(review): unlike the other sharding managers, no offload_param is
        # passed here — confirm whether FSDPAsyncSGLangShardingManager accepts it.
        rollout_sharding_manager = FSDPAsyncSGLangShardingManager(
            module=self.actor_module_fsdp,
            inference_engine=rollout._engine,
            model_config=self.actor_model_config,
            full_params="hf" in self.config.rollout.load_format,
            device_mesh=rollout_device_mesh,
        )
        log_gpu_memory_usage("After building sharding manager", logger=logger)

    else:
        raise NotImplementedError(f"Rollout name: {self.config.rollout.name} is not supported")

    return rollout, rollout_sharding_manager
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
    """Build the models, rollout engine and checkpoint manager for this worker's roles.

    Depending on the role flags this creates the FSDP actor module (shared by
    actor and rollout), the ``DataParallelPPOActor`` wrappers for the actor
    and reference policies, the rollout engine with its sharding manager, and
    for actors the FLOPs counter and checkpoint manager.
    """
    from verl.workers.actor import DataParallelPPOActor

    model_cfg = self.config.model

    # This is used to import external_lib into the huggingface systems
    import_external_libs(model_cfg.get("external_lib", None))

    from omegaconf import OmegaConf

    override_model_config = OmegaConf.to_container(model_cfg.get("override_config", OmegaConf.create()))
    use_remove_padding = model_cfg.get("use_remove_padding", False)
    trust_remote_code = model_cfg.get("trust_remote_code", False)
    use_liger = model_cfg.get("use_liger", False)

    needs_actor_module = self._is_actor or self._is_rollout
    if needs_actor_module:
        # Actor and rollout share one model; only a real actor gets an
        # optimizer and its configured FSDP options.
        optim_config = self.config.actor.optim if self._is_actor else None
        fsdp_config = self.config.actor.fsdp_config if self._is_actor else OmegaConf.create()

        built = self._build_model_optimizer(
            model_path=model_cfg.path,
            fsdp_config=fsdp_config,
            optim_config=optim_config,
            override_model_config=override_model_config,
            use_remove_padding=use_remove_padding,
            enable_gradient_checkpointing=model_cfg.get("enable_gradient_checkpointing", False),
            trust_remote_code=trust_remote_code,
            use_liger=use_liger,
            role="actor",
        )
        self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler, self.actor_model_config = built

        # Keep a handle on the unwrapped module (only set for FSDP version 1).
        if fsdp_version(self.actor_module_fsdp) == 1:
            self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module

        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.actor_module_fsdp)
            log_gpu_memory_usage("After offload actor model during init", logger=logger)

        if self._is_offload_optimizer:
            offload_fsdp_optimizer(optimizer=self.actor_optimizer)
            log_gpu_memory_usage("After offload actor optimizer during init", logger=logger)

    # Wrap the actor module in the PPO actor.
    if self._is_actor:
        actor_cfg = self.config.actor
        OmegaConf.set_struct(actor_cfg, True)
        # open_dict temporarily lifts struct mode so the new key can be added.
        with open_dict(actor_cfg):
            actor_cfg.use_remove_padding = use_remove_padding
        self.actor = DataParallelPPOActor(
            config=actor_cfg,
            actor_module=self.actor_module_fsdp,
            actor_optimizer=self.actor_optimizer,
        )

    if self._is_rollout:
        self.rollout, self.rollout_sharding_manager = self._build_rollout(trust_remote_code=trust_remote_code)

    if self._is_ref:
        # Only the module is kept; the ref policy has no optimizer or scheduler.
        ref_built = self._build_model_optimizer(
            model_path=model_cfg.path,
            fsdp_config=self.config.ref.fsdp_config,
            optim_config=None,
            override_model_config=override_model_config,
            use_remove_padding=use_remove_padding,
            trust_remote_code=trust_remote_code,
            use_liger=use_liger,
            role="ref",
        )
        self.ref_module_fsdp = ref_built[0]
        ref_cfg = self.config.ref
        OmegaConf.set_struct(ref_cfg, True)
        with open_dict(ref_cfg):
            ref_cfg.use_remove_padding = use_remove_padding
        self.ref_policy = DataParallelPPOActor(config=ref_cfg, actor_module=self.ref_module_fsdp)

    if self._is_actor:
        self.flops_counter = FlopsCounter(self.actor_model_config)
        # Prefer the processor when one exists, otherwise fall back to the tokenizer.
        processing_class = self.processor if self.processor is not None else self.tokenizer
        self.checkpoint_manager = FSDPCheckpointManager(
            model=self.actor_module_fsdp,
            optimizer=self.actor.actor_optimizer,
            lr_scheduler=self.actor_lr_scheduler,
            processing_class=processing_class,
            checkpoint_contents=self.config.actor.checkpoint.contents,
        )
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def update_actor(self, data: DataProto):
    """Run one policy update on ``data`` and return the collected metrics.

    Args:
        data: Training batch; ``data.meta_info["global_token_num"]`` is read
            for the FLOPs estimate.

    Returns:
        A ``DataProto`` on CPU whose ``meta_info["metrics"]`` holds the update
        metrics plus MFU, memory and learning-rate entries.
    """
    # Support all hardwares
    device = torch.cuda.current_device()
    data = data.to(device)

    assert self._is_actor
    # Bring offloaded state back onto the GPU before training.
    if self._is_offload_param:
        load_fsdp_model_to_gpu(self.actor_module_fsdp)
    if self._is_offload_optimizer:
        load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=device)

    sharding = self.ulysses_sharding_manager
    with sharding:
        data = sharding.preprocess_data(data=data)

        # Time the policy update itself; the elapsed time feeds the MFU estimate.
        with Timer(name="update_policy", logger=None) as timer:
            metrics = self.actor.update_policy(data=data)
        elapsed = timer.last

        token_count = data.meta_info["global_token_num"]
        estimated_flops, promised_flops = self.flops_counter.estimate_flops(token_count, elapsed)
        gib = 1024**3
        metrics["perf/mfu/actor"] = estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size
        metrics["perf/max_memory_allocated_gb"] = torch.cuda.max_memory_allocated() / gib
        metrics["perf/max_memory_reserved_gb"] = torch.cuda.max_memory_reserved() / gib
        metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / gib

        # Advance the schedule, then report the learning rate it now holds.
        self.actor_lr_scheduler.step()
        metrics["actor/lr"] = self.actor_lr_scheduler.get_last_lr()[0]

        # TODO: here, we should return all metrics
        result = DataProto(meta_info={"metrics": metrics})
        result = sharding.postprocess_data(data=result)
        result = result.to("cpu")

    # Push state back to the CPU if offloading is enabled.
    if self._is_offload_param:
        offload_fsdp_model_to_cpu(self.actor_module_fsdp)
        log_gpu_memory_usage("After offload actor model during update_actor", logger=logger)
    if self._is_offload_optimizer:
        offload_fsdp_optimizer(optimizer=self.actor_optimizer)
        log_gpu_memory_usage("After offload actor optimizer during update_actor", logger=logger)

    return result
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def generate_sequences(self, prompts: DataProto):
    """Generate sequences for ``prompts`` with the configured rollout engine.

    Args:
        prompts: Prompt batch; its ``meta_info`` is updated with the eos/pad
            token ids and the ``do_sample`` flag before generation.

    Returns:
        The generated batch as a ``DataProto`` moved to CPU.
    """
    # Support all hardwares
    prompts = prompts.to(torch.cuda.current_device())

    assert self._is_rollout

    # Token ids come from the generation config when there is one, otherwise
    # from the tokenizer.
    if self.generation_config is not None:
        token_id_source = self.generation_config
    else:
        token_id_source = self.tokenizer
    prompts.meta_info.update(
        {
            "eos_token_id": token_id_source.eos_token_id,
            "pad_token_id": token_id_source.pad_token_id,
            "do_sample": self.config.rollout.do_sample,
        }
    )

    sharding = self.rollout_sharding_manager
    with sharding:
        log_gpu_memory_usage("After entering rollout sharding manager", logger=logger)

        prompts = sharding.preprocess_data(prompts)

        # Tool-augmented generation is used only for an AsyncSGLangRollout
        # that carries a non-empty ``_tool_schemas``.
        use_tools = False
        if self.config.rollout.name == "sglang_async":
            from verl.workers.rollout.sglang_rollout import AsyncSGLangRollout

            use_tools = isinstance(self.rollout, AsyncSGLangRollout) and hasattr(self.rollout, "_tool_schemas") and len(self.rollout._tool_schemas) > 0

        if use_tools:
            generated = self.rollout.generate_sequences_with_tools(prompts=prompts)
        else:
            generated = self.rollout.generate_sequences(prompts=prompts)
        log_gpu_memory_usage("After rollout generation", logger=logger)

        generated = sharding.postprocess_data(generated)

    generated = generated.to("cpu")

    # clear kv cache
    torch.cuda.empty_cache()
    return generated
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_log_prob(self, data: DataProto):
    """Recompute log-probabilities and entropies of ``data`` under the actor.

    Args:
        data: Batch to score; its ``meta_info`` is filled with the rollout
            log-prob settings before the forward pass.

    Returns:
        A ``DataProto`` on CPU with tensors ``"old_log_probs"`` and
        ``"entropys"`` and the rollout temperature in ``meta_info``.
    """
    assert self._is_actor
    if self._is_offload_param:
        load_fsdp_model_to_gpu(self.actor_module_fsdp)

    # Support all hardwares
    data = data.to(torch.cuda.current_device())

    # we should always recompute old_log_probs when it is HybridEngine
    rollout_cfg = self.config.rollout
    data.meta_info["micro_batch_size"] = rollout_cfg.log_prob_micro_batch_size_per_gpu
    data.meta_info["max_token_len"] = rollout_cfg.log_prob_max_token_len_per_gpu
    data.meta_info["use_dynamic_bsz"] = rollout_cfg.log_prob_use_dynamic_bsz
    data.meta_info["temperature"] = rollout_cfg.temperature

    # Recompute the log-probs inside the Ulysses sharding context.
    sharding = self.ulysses_sharding_manager
    with sharding:
        data = sharding.preprocess_data(data)
        log_probs, entropies = self.actor.compute_log_prob(data=data, calculate_entropy=True)
        result = DataProto.from_dict(
            tensors={"old_log_probs": log_probs, "entropys": entropies},
            meta_info={"temperature": rollout_cfg.temperature},
        )
        result = sharding.postprocess_data(result)

    result = result.to("cpu")

    # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
    # Explicitly reshard the root FSDP module's handle (FSDP version 1, multi-rank only).
    multi_rank = self.world_size > 1
    if multi_rank and fsdp_version(self.actor.actor_module) == 1:
        self.actor.actor_module._handle.reshard(True)

    if self._is_offload_param:
        offload_fsdp_model_to_cpu(self.actor_module_fsdp)
        log_gpu_memory_usage("After offload actor model during compute_log_prob", logger=logger)

    return result
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_ref_log_prob(self, data: DataProto):
    """Compute log-probabilities of ``data`` under the reference policy.

    Args:
        data: Batch to score. Its ``meta_info`` is filled with the ref
            log-prob batching settings and the rollout temperature.

    Returns:
        A CPU ``DataProto`` holding the tensor ``ref_log_prob``.
    """
    assert self._is_ref

    # Support all hardwares
    batch = data.to(torch.cuda.current_device())

    ref_cfg = self.config.ref
    batch.meta_info.update(
        {
            "micro_batch_size": ref_cfg.log_prob_micro_batch_size_per_gpu,
            "temperature": self.config.rollout.temperature,
            "max_token_len": ref_cfg.log_prob_max_token_len_per_gpu,
            "use_dynamic_bsz": ref_cfg.log_prob_use_dynamic_bsz,
        }
    )

    sharding_manager = self.ulysses_sharding_manager
    with sharding_manager:
        batch = sharding_manager.preprocess_data(batch)
        # Entropy is not requested, so the second return value is dropped.
        ref_log_prob, _ = self.ref_policy.compute_log_prob(data=batch, calculate_entropy=False)
        result = DataProto.from_dict(tensors={"ref_log_prob": ref_log_prob})
        result = sharding_manager.postprocess_data(result)

    result = result.to("cpu")

    # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
    # Reshard the root module's handle; only done for FSDP version 1 when
    # running on more than one rank.
    ref_module = self.ref_policy.actor_module
    if self.world_size > 1 and fsdp_version(ref_module) == 1:
        ref_module._handle.reshard(True)

    return result
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
    """Save the actor checkpoint through ``self.checkpoint_manager``.

    Args:
        local_path: Forwarded to the checkpoint manager as ``local_path``.
        hdfs_path: Forwarded as ``hdfs_path``.
        global_step: Forwarded as ``global_step``.
        max_ckpt_to_keep: Forwarded as ``max_ckpt_to_keep``.
    """
    # only support save and load ckpt for actor
    assert self._is_actor
    import torch

    params_are_offloaded = self._is_offload_param
    if params_are_offloaded:
        # Bring the parameters back onto the GPU for the duration of the save.
        load_fsdp_model_to_gpu(self.actor_module_fsdp)

    self.checkpoint_manager.save_checkpoint(
        local_path=local_path,
        hdfs_path=hdfs_path,
        global_step=global_step,
        max_ckpt_to_keep=max_ckpt_to_keep,
    )

    # Wait for every rank before the parameters are offloaded again.
    torch.distributed.barrier()

    if params_are_offloaded:
        offload_fsdp_model_to_cpu(self.actor_module_fsdp)
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False):
    """Load a checkpoint through ``self.checkpoint_manager``.

    Args:
        local_path: Forwarded to the checkpoint manager as ``local_path``.
        hdfs_path: Forwarded as ``hdfs_path``.
        del_local_after_load: Forwarded as ``del_local_after_load``.
    """
    params_are_offloaded = self._is_offload_param
    if params_are_offloaded:
        # Bring the parameters back onto the GPU for the duration of the load.
        load_fsdp_model_to_gpu(self.actor_module_fsdp)

    self.checkpoint_manager.load_checkpoint(
        local_path=local_path,
        hdfs_path=hdfs_path,
        del_local_after_load=del_local_after_load,
    )

    if params_are_offloaded:
        offload_fsdp_model_to_cpu(self.actor_module_fsdp)

    if self._is_offload_optimizer:
        offload_fsdp_optimizer(self.actor_optimizer)
class CriticWorker(Worker):
    """Worker that hosts the critic (value) model under FSDP.

    It builds the critic module, optimizer and LR scheduler, and exposes
    dispatchable entry points for value computation, critic updates and
    checkpoint save/load.
    """

    def __init__(self, config):
        """Set up distributed state and normalize the batch-size settings.

        Args:
            config: Critic configuration. ``ppo_mini_batch_size``,
                ``ppo_micro_batch_size`` and ``forward_micro_batch_size`` are
                rewritten in place: the mini batch is first multiplied by
                ``rollout_n``, then all three are floor-divided by the
                data-parallel size (``world_size // sp_size``).
        """
        super().__init__()
        import torch.distributed

        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend="nccl")
        self.config = config

        # build device mesh for Ulysses Sequence Parallel
        world_size = torch.distributed.get_world_size()
        from torch.distributed.device_mesh import init_device_mesh

        fsdp_size = self.config.model.fsdp_config.fsdp_size
        self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size)

        self.ulysses_device_mesh = None
        self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1)
        dp = world_size // self.ulysses_sequence_parallel_size
        if self.ulysses_sequence_parallel_size > 1:
            self.ulysses_device_mesh = init_device_mesh(
                "cuda",
                mesh_shape=(dp, self.ulysses_sequence_parallel_size),
                mesh_dim_names=["dp", "sp"],
            )

        self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)

        # set FSDP offload params
        self._is_offload_param = self.config.model.fsdp_config.param_offload
        self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload

        # normalize config: `dp` is world_size // sp_size computed above
        self.config.ppo_mini_batch_size *= self.config.rollout_n
        self.config.ppo_mini_batch_size //= dp
        if self.config.ppo_micro_batch_size is not None:
            self.config.ppo_micro_batch_size //= dp
            self.config.forward_micro_batch_size //= dp
            self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size
            self.config.forward_micro_batch_size_per_gpu = self.config.forward_micro_batch_size

        if self.config.ppo_micro_batch_size_per_gpu is not None:
            assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0, f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be divisible by ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}"
            assert self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu > 0, f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}"

    def _build_critic_model_optimizer(self, config):
        """Build the FSDP-wrapped critic, its AdamW optimizer and LR scheduler.

        Also sets ``self.tokenizer``, ``self.processor`` and
        ``self.critic_model_config`` as side effects.

        Args:
            config: Critic configuration (model path, strategy, optim, ...).

        Returns:
            Tuple ``(critic_module, critic_optimizer, critic_lr_scheduler)``.

        Raises:
            NotImplementedError: If ``config.strategy`` is neither ``"fsdp"``
                nor ``"fsdp2"``, or the warmup style is neither
                ``"constant"`` nor ``"cosine"``.
        """
        # the following line is necessary
        from torch import optim
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        from torch.distributed.fsdp import MixedPrecision

        from verl.utils.model import print_model_size
        from verl.utils.torch_dtypes import PrecisionType

        trust_remote_code = config.model.get("trust_remote_code", False)

        local_path = copy_to_local(config.model.path)
        # note that the tokenizer between actor and critic may be different. So override tokenizer info with actor info
        # using random initialized model from any architecture. May not be the same as Actor.

        tokenizer_path = copy_to_local(config.model.tokenizer_path)
        self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=trust_remote_code)
        self.processor = hf_processor(tokenizer_path, trust_remote_code=trust_remote_code)

        from omegaconf import OmegaConf

        override_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create()))
        override_config_kwargs = {
            "bos_token_id": self.tokenizer.bos_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        override_config_kwargs.update(override_config)
        # NOTE(review): within this method the override dict is only printed;
        # it is never written into `critic_model_config`. Confirm whether the
        # overrides are meant to be applied.
        if self.rank == 0:
            print(f"Critic overriding config {override_config_kwargs}")

        torch_dtype = self.config.model.fsdp_config.get("model_dtype", "fp32")
        torch_dtype = PrecisionType.to_dtype(torch_dtype)

        from transformers import AutoConfig, AutoModelForTokenClassification

        critic_model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)
        # A single label turns the token-classification head into a value head.
        critic_model_config.num_labels = 1

        init_context = get_init_weight_context_manager(
            use_meta_tensor=not critic_model_config.tie_word_embeddings,
            mesh=self.device_mesh,
        )

        with init_context(), warnings.catch_warnings():
            warnings.simplefilter("ignore")
            critic_model_config.classifier_dropout = 0.0
            # Dropout probabilities must be numeric; this used to be the
            # string "0", which is not a valid probability for models that
            # read `hidden_dropout`.
            critic_model_config.hidden_dropout = 0.0
            critic_module = AutoModelForTokenClassification.from_pretrained(
                pretrained_model_name_or_path=local_path,
                torch_dtype=torch_dtype,
                config=critic_model_config,
                attn_implementation="flash_attention_2",
                trust_remote_code=trust_remote_code,
            )

            use_remove_padding = config.model.get("use_remove_padding", False)
            if use_remove_padding or self.ulysses_sequence_parallel_size > 1:
                from verl.models.transformers.monkey_patch import apply_monkey_patch

                apply_monkey_patch(model=critic_module, ulysses_sp_size=self.ulysses_sequence_parallel_size)

            # some parameters may not in torch_dtype
            critic_module.to(torch_dtype)

            if config.model.get("enable_gradient_checkpointing", False):
                critic_module.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

        if self.rank == 0:
            print_model_size(critic_module)

        self.critic_model_config = critic_model_config

        fsdp_config = self.config.model.fsdp_config
        mixed_precision_config = fsdp_config.get("mixed_precision", None)
        if mixed_precision_config is not None:
            param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16"))
            reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get("reduce_dtype", "fp32"))
            buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get("buffer_dtype", "fp32"))
        else:
            param_dtype = torch.bfloat16
            reduce_dtype = torch.float32
            buffer_dtype = torch.float32

        mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)

        auto_wrap_policy = get_fsdp_wrap_policy(module=critic_module, config=self.config.model.fsdp_config.wrap_policy)

        log_gpu_memory_usage("Before critic FSDP", logger=None)

        fsdp_mesh = self.device_mesh
        sharding_strategy = get_sharding_strategy(fsdp_mesh)

        # Note: We force turn off CPUOffload for critic because it causes incorrect results when using grad accumulation
        if config.strategy == "fsdp":
            critic_module = FSDP(
                critic_module,
                param_init_fn=init_fn,
                use_orig_params=False,
                auto_wrap_policy=auto_wrap_policy,
                device_id=torch.cuda.current_device(),
                sharding_strategy=sharding_strategy,
                mixed_precision=mixed_precision,
                sync_module_states=True,
                forward_prefetch=False,
                device_mesh=self.device_mesh,
                cpu_offload=None,
            )
        elif config.strategy == "fsdp2":
            assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"
            mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True)
            offload_policy = None
            if fsdp_config.offload_policy:
                # With an FSDP2 offload policy the manual load/offload paths
                # of this worker are switched off.
                self._is_offload_param = False
                self._is_offload_optimizer = False
                offload_policy = CPUOffloadPolicy(pin_memory=True)

            fsdp_kwargs = {
                "mesh": fsdp_mesh,
                "mp_policy": mp_policy,
                "offload_policy": offload_policy,
                "reshard_after_forward": fsdp_config.reshard_after_forward,
            }
            # Capture the full state before sharding, then load it back into
            # the sharded module.
            full_state = critic_module.state_dict()
            apply_fsdp2(critic_module, fsdp_kwargs, fsdp_config)
            fsdp2_load_full_state_dict(critic_module, full_state, fsdp_mesh, offload_policy)
        else:
            raise NotImplementedError(f"Unknown strategy {config.strategy}")

        log_gpu_memory_usage("After critic FSDP", logger=None)

        critic_optimizer = optim.AdamW(
            critic_module.parameters(),
            lr=config.optim.lr,
            betas=config.optim.get("betas", (0.9, 0.999)),
            weight_decay=config.optim.get("weight_decay", 1e-2),
        )

        total_steps = config.optim.get("total_training_steps", 0)
        num_warmup_steps = int(config.optim.get("lr_warmup_steps", -1))
        warmup_style = config.optim.get("warmup_style", "constant")
        if num_warmup_steps < 0:
            # A negative step count means "derive it from the warmup ratio".
            num_warmup_steps_ratio = config.optim.get("lr_warmup_steps_ratio", 0.0)
            num_warmup_steps = int(num_warmup_steps_ratio * total_steps)

        print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}")

        from verl.utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup

        if warmup_style == "constant":
            critic_lr_scheduler = get_constant_schedule_with_warmup(optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps)
        elif warmup_style == "cosine":
            critic_lr_scheduler = get_cosine_schedule_with_warmup(optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps)
        else:
            raise NotImplementedError(f"Warmup style {warmup_style} is not supported")

        return critic_module, critic_optimizer, critic_lr_scheduler

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def init_model(self):
        """Build the critic module, optimizer, scheduler and helper objects.

        Sets ``critic_module``, ``critic_optimizer``, ``critic_lr_scheduler``,
        ``critic``, ``flops_counter`` and ``checkpoint_manager`` on ``self``.
        """
        # This is used to import external_lib into the huggingface systems
        import_external_libs(self.config.model.get("external_lib", None))

        from verl.workers.critic import DataParallelPPOCritic

        self.critic_module, self.critic_optimizer, self.critic_lr_scheduler = self._build_critic_model_optimizer(self.config)

        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.critic_module)
            log_gpu_memory_usage("After offload critic model during init", logger=logger)
        if self._is_offload_optimizer:
            offload_fsdp_optimizer(optimizer=self.critic_optimizer)
            log_gpu_memory_usage("After offload critic optimizer during init", logger=logger)

        self.critic = DataParallelPPOCritic(config=self.config, critic_module=self.critic_module, critic_optimizer=self.critic_optimizer)

        self.flops_counter = FlopsCounter(self.critic_model_config)
        self.checkpoint_manager = FSDPCheckpointManager(
            model=self.critic_module,
            optimizer=self.critic_optimizer,
            lr_scheduler=self.critic_lr_scheduler,
            # Prefer the processor when one was loaded, else the tokenizer.
            processing_class=self.processor if self.processor is not None else self.tokenizer,
            checkpoint_contents=self.config.checkpoint.contents,
        )

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def compute_values(self, data: DataProto):
        """Run the critic forward pass on ``data``.

        Args:
            data: Batch to evaluate. Its ``meta_info`` is filled with the
                forward micro-batch settings from the config.

        Returns:
            A CPU ``DataProto`` holding the tensor ``values``.
        """
        # Support all hardwares
        data = data.to(torch.cuda.current_device())

        if self._is_offload_param:
            load_fsdp_model_to_gpu(self.critic_module)
        micro_batch_size = self.config.forward_micro_batch_size_per_gpu
        data.meta_info["micro_batch_size"] = micro_batch_size
        data.meta_info["max_token_len"] = self.config.forward_max_token_len_per_gpu
        data.meta_info["use_dynamic_bsz"] = self.config.use_dynamic_bsz
        # perform forward computation
        with self.ulysses_sharding_manager:
            data = self.ulysses_sharding_manager.preprocess_data(data=data)
            values = self.critic.compute_values(data=data)
            output = DataProto.from_dict(tensors={"values": values})
            output = self.ulysses_sharding_manager.postprocess_data(data=output)

        output = output.to("cpu")
        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.critic_module)
        return output

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def update_critic(self, data: DataProto):
        """Run one critic update on ``data`` and step the LR scheduler.

        Args:
            data: Training batch; ``meta_info["global_token_num"]`` is read
                for the FLOPs estimate.

        Returns:
            A CPU ``DataProto`` without tensors whose
            ``meta_info["metrics"]`` holds the update metrics plus
            ``perf/mfu/critic`` and ``critic/lr``.
        """
        # Support all hardwares
        data = data.to(torch.cuda.current_device())
        if self._is_offload_param:
            load_fsdp_model_to_gpu(self.critic_module)
        if self._is_offload_optimizer:
            load_fsdp_optimizer(optimizer=self.critic_optimizer, device_id=torch.cuda.current_device())

        # perform forward computation
        with self.ulysses_sharding_manager:
            data = self.ulysses_sharding_manager.preprocess_data(data=data)

            with Timer(name="update_critic", logger=None) as timer:
                metrics = self.critic.update_critic(data=data)
            delta_time = timer.last

            global_num_tokens = data.meta_info["global_token_num"]
            estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)
            metrics["perf/mfu/critic"] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size

            self.critic_lr_scheduler.step()
            lr = self.critic_lr_scheduler.get_last_lr()[0]
            metrics["critic/lr"] = lr

            output = DataProto(batch=None, meta_info={"metrics": metrics})
            output = self.ulysses_sharding_manager.postprocess_data(data=output)

        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.critic_module)
        if self._is_offload_optimizer:
            offload_fsdp_optimizer(optimizer=self.critic_optimizer)

        output = output.to("cpu")
        return output

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
        """Save the critic checkpoint through ``self.checkpoint_manager``.

        All arguments are forwarded unchanged to the checkpoint manager; a
        distributed barrier follows the save.
        """
        import torch

        if self._is_offload_param:
            load_fsdp_model_to_gpu(self.critic_module)

        self.checkpoint_manager.save_checkpoint(local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep)

        torch.distributed.barrier()
        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.critic_module)

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):
        """Load the critic checkpoint through ``self.checkpoint_manager``.

        All arguments are forwarded unchanged to the checkpoint manager
        (``del_local_after_load`` defaults to ``True`` here); a distributed
        barrier follows the load.
        """
        import torch

        if self._is_offload_param:
            load_fsdp_model_to_gpu(self.critic_module)

        self.checkpoint_manager.load_checkpoint(local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load)

        torch.distributed.barrier()
        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.critic_module)

        if self._is_offload_optimizer:
            offload_fsdp_optimizer(self.critic_optimizer)
# TODO(sgm): we may need to extract it to dp_reward_model.py
class RewardModelWorker(Worker):
    """
    Note that we only implement the reward model that is subclass of AutoModelForTokenClassification.
    """

    def __init__(self, config):
        """Set up distributed state and normalize ``micro_batch_size``.

        Args:
            config: Reward-model configuration. When ``micro_batch_size`` is
                set it is floor-divided in place by the world size and copied
                to ``micro_batch_size_per_gpu``.
        """
        super().__init__()
        import torch.distributed

        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend="nccl")
        self.config = config

        # build device mesh for Ulysses Sequence Parallel
        world_size = torch.distributed.get_world_size()
        from torch.distributed.device_mesh import init_device_mesh

        fsdp_size = self.config.model.fsdp_config.fsdp_size
        self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size)

        self.ulysses_device_mesh = None
        self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1)
        dp = world_size // self.ulysses_sequence_parallel_size
        if self.ulysses_sequence_parallel_size > 1:
            self.ulysses_device_mesh = init_device_mesh(
                "cuda",
                mesh_shape=(dp, self.ulysses_sequence_parallel_size),
                mesh_dim_names=["dp", "sp"],
            )

        self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)

        self.use_remove_padding = self.config.model.get("use_remove_padding", False)

        # normalize config
        # NOTE(review): this divides by the full world size, whereas
        # CriticWorker divides by world_size // sp_size - confirm intended.
        if self.config.micro_batch_size is not None:
            self.config.micro_batch_size //= torch.distributed.get_world_size()
            self.config.micro_batch_size_per_gpu = self.config.micro_batch_size

    def _build_model(self, config):
        """Load the reward model in bfloat16 and wrap it with FSDP.

        Also sets ``self._do_switch_chat_template`` and, when an input
        tokenizer is configured, ``self.input_tokenizer`` and
        ``self.tokenizer``.

        Args:
            config: Reward-model configuration (model path, strategy, ...).

        Returns:
            The FSDP-wrapped reward module.

        Raises:
            NotImplementedError: If ``config.strategy`` is neither ``"fsdp"``
                nor ``"fsdp2"``.
        """
        # the following line is necessary
        from torch.distributed.fsdp import CPUOffload
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        from transformers import AutoConfig, AutoModelForTokenClassification

        trust_remote_code = config.model.get("trust_remote_code", False)

        # download the checkpoint from hdfs
        local_path = copy_to_local(config.model.path)

        if self.config.model.input_tokenizer is None:
            self._do_switch_chat_template = False
        else:
            # The prompts were tokenized with a different tokenizer, so both
            # tokenizers are needed to re-render the chat for this model.
            self._do_switch_chat_template = True
            input_tokenizer_local_path = copy_to_local(config.model.input_tokenizer)
            self.input_tokenizer = hf_tokenizer(input_tokenizer_local_path, trust_remote_code=trust_remote_code)
            self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)

        model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)
        # A single label turns the token-classification head into a score head.
        model_config.num_labels = 1

        # The reward model is only used for inference in this worker (no
        # optimizer is built), so it is loaded directly in bfloat16.
        init_context = get_init_weight_context_manager(
            use_meta_tensor=not model_config.tie_word_embeddings,
            mesh=self.device_mesh,
        )

        with init_context(), warnings.catch_warnings():
            warnings.simplefilter("ignore")
            model_config.classifier_dropout = 0.0
            reward_module = AutoModelForTokenClassification.from_pretrained(
                pretrained_model_name_or_path=local_path,
                config=model_config,
                torch_dtype=torch.bfloat16,
                attn_implementation="flash_attention_2",
                trust_remote_code=trust_remote_code,
            )

            if config.model.get("use_remove_padding", False) or self.ulysses_sequence_parallel_size > 1:
                from verl.models.transformers.monkey_patch import apply_monkey_patch

                apply_monkey_patch(model=reward_module, ulysses_sp_size=self.ulysses_sequence_parallel_size)

            reward_module.to(torch.bfloat16)

        # NOTE(review): CriticWorker passes `fsdp_config.wrap_policy` here,
        # this passes the whole `fsdp_config` - confirm which is expected.
        auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.model.fsdp_config)

        fsdp_mesh = self.device_mesh
        sharding_strategy = get_sharding_strategy(fsdp_mesh)

        if config.strategy == "fsdp":
            reward_module = FSDP(
                reward_module,
                param_init_fn=init_fn,
                use_orig_params=False,
                auto_wrap_policy=auto_wrap_policy,
                device_id=torch.cuda.current_device(),
                sharding_strategy=sharding_strategy,  # zero3
                sync_module_states=True,
                cpu_offload=CPUOffload(offload_params=True),
                forward_prefetch=False,
                device_mesh=self.device_mesh,
            )
        elif config.strategy == "fsdp2":
            assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"
            cpu_offload = CPUOffloadPolicy(pin_memory=True)
            fsdp_kwargs = {
                "mesh": fsdp_mesh,
                "offload_policy": cpu_offload,
                "reshard_after_forward": config.model.fsdp_config.reshard_after_forward,
            }
            # Capture the full state before sharding, then load it back into
            # the sharded module.
            full_state = reward_module.state_dict()
            apply_fsdp2(reward_module, fsdp_kwargs, config.model.fsdp_config)
            fsdp2_load_full_state_dict(reward_module, full_state, fsdp_mesh, cpu_offload)
        else:
            raise NotImplementedError(f"Unknown strategy: {config.strategy}")
        return reward_module

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def init_model(self):
        """Import optional external libs and build ``self.reward_module``."""
        # This is used to import external_lib into the huggingface systems
        import_external_libs(self.config.model.get("external_lib", None))
        self.reward_module = self._build_model(config=self.config)

    def _forward_micro_batch(self, micro_batch):
        """Score one micro batch and return one score per sequence.

        Args:
            micro_batch: Mapping with ``input_ids``, ``attention_mask`` and
                ``position_ids``; ``input_ids`` is unpacked as
                ``(batch_size, seqlen)``.

        Returns:
            Tensor with one entry per sequence: the model output at the index
            ``argmax(position_ids * attention_mask)`` of each row.
        """
        from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input

        from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs

        with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            input_ids = micro_batch["input_ids"]
            batch_size, seqlen = input_ids.shape
            attention_mask = micro_batch["attention_mask"]
            position_ids = micro_batch["position_ids"]

            if self.use_remove_padding:
                input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask)  # input_ids_rmpad (total_nnz, ...)
                input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)

                # unpad the position_ids to align the rotary
                position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices).transpose(0, 1)

                # pad and slice the inputs if sp > 1
                if self.ulysses_sequence_parallel_size > 1:
                    input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size)

                # only pass input_ids and position_ids to enable flash_attn_varlen
                output = self.reward_module(input_ids=input_ids_rmpad, attention_mask=None, position_ids=position_ids_rmpad, use_cache=False)  # prevent model thinks we are generating
                reward_rmpad = output.logits
                reward_rmpad = reward_rmpad.squeeze(0)  # (total_nnz)

                # gather output if sp > 1 (`pad_size` is defined under the same condition above)
                if self.ulysses_sequence_parallel_size > 1:
                    reward_rmpad = gather_outpus_and_unpad(reward_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size)

                # pad it back
                rm_score = pad_input(reward_rmpad, indices=indices, batch=batch_size, seqlen=seqlen).squeeze(-1)
            else:
                output = self.reward_module(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False)
                rm_score = output.logits  # (batch_size, seq_len, 1)
                rm_score = rm_score.squeeze(-1)

            # extract the result of the last valid token
            eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1)  # (bsz,)
            rm_score = rm_score[torch.arange(batch_size), eos_mask_idx]
            return rm_score

    def _expand_to_token_level(self, data: DataProto, scores: torch.Tensor):
        """Scatter per-sequence ``scores`` into a token-level score tensor.

        Each score is written at index ``argmax(position_ids * attention_mask)``
        of its row; every other position is zero. Only the trailing
        ``response_length`` columns are returned.

        Args:
            data: Batch providing ``attention_mask``, ``position_ids`` and
                ``responses``.
            scores: One score per sequence.

        Returns:
            Tensor of shape ``(bsz, response_length)`` with dtype of ``scores``.
        """
        batch_size = data.batch.batch_size[0]
        # expand as token_level_reward
        attention_mask = data.batch["attention_mask"]
        position_ids = data.batch["position_ids"]
        response_length = data.batch["responses"].shape[-1]
        eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1)  # (bsz,)
        token_level_scores = torch.zeros_like(attention_mask, dtype=scores.dtype)  # (bsz, seqlen)
        token_level_scores[torch.arange(batch_size), eos_mask_idx] = scores

        # select the response part
        token_level_scores = token_level_scores[:, -response_length:]

        return token_level_scores

    def _switch_chat_template(self, data: DataProto):
        """Re-render each sample with the reward model's chat template.

        For every sample the response tokens are decoded with the input
        tokenizer, appended to the raw prompt as an assistant turn, rendered
        with the reward tokenizer's chat template and re-tokenized
        (right-padded, truncated per ``config.truncation``).

        Args:
            data: Batch with ``responses`` and ``attention_mask`` tensors and
                ``raw_prompt`` in ``non_tensor_batch``. The raw prompts are
                not modified.

        Returns:
            A ``DataProto`` with ``input_ids``, ``attention_mask`` and
            ``position_ids`` for the reward model.
        """
        src_max_length = data.batch["attention_mask"].shape[-1]

        src_tokenizer = self.input_tokenizer
        target_tokenizer = self.tokenizer

        # the maximum length is actually determined by the reward model itself
        # (loop-invariant, so resolved once up front)
        max_length = self.config.get("max_length", src_max_length)
        if max_length is None:
            max_length = src_max_length

        rm_input_ids = []
        rm_attention_mask = []

        for i in range(data.batch.batch_size[0]):
            # extract raw prompt; always work on a copy so that appending the
            # assistant turn below never mutates the caller's prompt list
            raw_prompt = data.non_tensor_batch["raw_prompt"][i]
            if isinstance(raw_prompt, list):
                chat: list = list(raw_prompt)
            else:
                chat: list = raw_prompt.tolist()

            # extract response
            response_ids = data.batch["responses"][i]
            response_length = response_ids.shape[-1]
            valid_response_length = data.batch["attention_mask"][i][-response_length:].sum()
            valid_response_ids = response_ids[:valid_response_length]

            # decode
            response = src_tokenizer.decode(valid_response_ids)
            # remove bos and eos
            response = response.replace(src_tokenizer.eos_token, "")

            chat.append({"role": "assistant", "content": response})

            prompt_with_chat_template = target_tokenizer.apply_chat_template(chat, add_generation_prompt=False, tokenize=False)
            if self.rank == 0 and i == 0:
                # for debugging purpose
                print(f"Switch template. chat: {prompt_with_chat_template}")

            model_inputs = target_tokenizer(prompt_with_chat_template, return_tensors="pt", add_special_tokens=False)
            input_ids, attention_mask = verl_F.postprocess_data(
                input_ids=model_inputs["input_ids"],
                attention_mask=model_inputs["attention_mask"],
                max_length=max_length,
                pad_token_id=target_tokenizer.pad_token_id,
                left_pad=False,  # right padding
                truncation=self.config.get("truncation", "right"),
            )  # truncate from the right

            rm_input_ids.append(input_ids)
            rm_attention_mask.append(attention_mask)

        rm_input_ids = torch.cat(rm_input_ids, dim=0)
        rm_attention_mask = torch.cat(rm_attention_mask, dim=0)

        rm_position_ids = compute_position_id_with_mask(rm_attention_mask)

        rm_inputs = {"input_ids": rm_input_ids, "attention_mask": rm_attention_mask, "position_ids": rm_position_ids}

        return DataProto.from_dict(rm_inputs)

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def compute_rm_score(self, data: DataProto):
        """Compute token-level reward-model scores for ``data``.

        Args:
            data: Batch with ``input_ids``, ``attention_mask``,
                ``position_ids`` and ``responses`` (plus ``raw_prompt`` when
                the chat template is switched).

        Returns:
            A CPU ``DataProto`` holding the tensor ``rm_scores``. These are
            raw scores, not necessarily the final rewards used for training.
        """
        import itertools

        from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches

        # Support all hardwares
        data = data.to(torch.cuda.current_device())
        if self._do_switch_chat_template:
            rm_data = self._switch_chat_template(data)
        else:
            rm_input_ids = data.batch["input_ids"]
            rm_attention_mask = data.batch["attention_mask"]
            rm_position_ids = data.batch["position_ids"]
            rm_inputs = {
                "input_ids": rm_input_ids,
                "attention_mask": rm_attention_mask,
                "position_ids": rm_position_ids,
            }
            rm_data = DataProto.from_dict(rm_inputs)

        # Support all hardwares
        rm_data.batch = rm_data.batch.to(torch.cuda.current_device())

        # perform forward computation
        with self.ulysses_sharding_manager:
            rm_data = self.ulysses_sharding_manager.preprocess_data(data=rm_data)
            data = self.ulysses_sharding_manager.preprocess_data(data=data)

            use_dynamic_bsz = self.config.use_dynamic_bsz
            if use_dynamic_bsz:
                max_token_len = self.config.forward_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
                micro_batches, indices = rearrange_micro_batches(batch=rm_data.batch, max_token_len=max_token_len)
            else:
                micro_batches = rm_data.batch.split(self.config.micro_batch_size_per_gpu)

            output = [self._forward_micro_batch(micro_batch) for micro_batch in micro_batches]
            scores = torch.cat(output, dim=0)  # (batch_size)

            if use_dynamic_bsz:
                # Dynamic batching reorders samples; restore the input order.
                indices = list(itertools.chain.from_iterable(indices))
                assert len(indices) == scores.size(0), f"{len(indices)} vs. {scores.size()}"
                revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
                scores = scores[revert_indices]

            token_level_scores = self._expand_to_token_level(data, scores)
            # Note that this is only the scores, may not be the final rewards used to train RL
            output = DataProto.from_dict(tensors={"rm_scores": token_level_scores})
            output = self.ulysses_sharding_manager.postprocess_data(data=output)

        # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
        # Reshard the root FSDP module. `_handle` is only touched for FSDP
        # version 1 on more than one rank, matching the guard used by the
        # actor/ref log-prob methods; this keeps the "fsdp2" strategy that
        # `_build_model` supports from hitting a missing attribute here.
        if self.world_size > 1 and fsdp_version(self.reward_module) == 1:
            self.reward_module._handle.reshard(True)

        output = output.to("cpu")
        return output
# ================================= Async related workers =================================
class AsyncActorRolloutRefWorker(ActorRolloutRefWorker):
    """Actor/rollout/ref worker variant whose rollout is driven via ``execute_method``."""

    def _build_rollout(self, trust_remote_code=False):
        """Build the rollout via the parent class and record vLLM rank info.

        Sets ``vllm_tp_size``, ``vllm_dp_rank`` and ``vllm_tp_rank`` from the
        rollout config and the ``RANK`` environment variable.

        Returns:
            The ``(rollout, rollout_sharding_manager)`` pair from the parent.
        """
        built = super()._build_rollout(trust_remote_code)
        rollout, rollout_sharding_manager = built

        # NOTE: rollout is not actually initialized here, it's deferred
        # to be initialized by AsyncvLLMServer.

        self.vllm_tp_size = self.config.rollout.tensor_model_parallel_size
        # Quotient of the global rank by the TP size is the DP rank, the
        # remainder is the TP rank.
        global_rank = int(os.environ["RANK"])
        self.vllm_dp_rank, self.vllm_tp_rank = divmod(global_rank, self.vllm_tp_size)

        # used for sleep/wake_up
        rollout.sharding_manager = rollout_sharding_manager

        return rollout, rollout_sharding_manager

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def generate_sequences(self, prompts: DataProto):
        """Not supported by this worker; always raises ``NotImplementedError``."""
        raise NotImplementedError("AsyncActorRolloutRefWorker does not support generate_sequences")

    @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)
    def execute_method(self, method: Union[str, bytes], *args, **kwargs):
        """Called by ExternalRayDistributedExecutor collective_rpc."""
        should_log = self.vllm_tp_rank == 0 and method != "execute_model"
        if should_log:
            # Non-string methods are logged under a generic label.
            method_label = method if isinstance(method, str) else "Callable"
            print(f"[DP={self.vllm_dp_rank},TP={self.vllm_tp_rank}] execute_method: {method_label}")
        return self.rollout.execute_method(method, *args, **kwargs)
|