Add files using upload-large-folder tool
Browse files- code/RL_model/verl/verl_train/tests/single_controller/__init__.py +13 -0
- code/RL_model/verl/verl_train/tests/single_controller/test_auto_padding_on_cpu.py +152 -0
- code/RL_model/verl/verl_train/tests/single_controller/test_colocated_workers.py +86 -0
- code/RL_model/verl/verl_train/tests/single_controller/test_colocated_workers_fused.py +86 -0
- code/RL_model/verl/verl_train/tests/single_controller/test_data_transfer.py +109 -0
- code/RL_model/verl/verl_train/tests/single_controller/test_decorator_on_cpu.py +200 -0
- code/RL_model/verl/verl_train/tests/single_controller/test_device_mesh_register.py +158 -0
- code/RL_model/verl/verl_train/tests/single_controller/test_driverfunc_to_worker.py +85 -0
- code/RL_model/verl/verl_train/tests/single_controller/test_fused_workers_on_cpu.py +90 -0
- code/RL_model/verl/verl_train/tests/single_controller/test_high_level_scheduling_api.py +103 -0
- code/RL_model/verl/verl_train/tests/single_controller/test_rvdz.py +51 -0
- code/RL_model/verl/verl_train/tests/single_controller/test_worker_group_torch.py +116 -0
- code/RL_model/verl/verl_train/tests/special_e2e/README.md +1 -0
- code/RL_model/verl/verl_train/tests/utils/test_activation_offload.py +175 -0
- code/RL_model/verl/verl_train/tests/utils/test_check_ipc_version_support_on_npu.py +231 -0
- code/RL_model/verl/verl_train/tests/utils/test_config_on_cpu.py +97 -0
- code/RL_model/verl/verl_train/tests/utils/test_flops_counter.py +480 -0
- code/RL_model/verl/verl_train/tests/utils/test_fs_on_cpu.py +94 -0
- code/RL_model/verl/verl_train/tests/utils/test_groupwise.py +98 -0
- code/RL_model/verl/verl_train/tests/utils/test_import_utils_on_cpu.py +97 -0
- code/RL_model/verl/verl_train/tests/utils/test_linear_cross_entropy.py +361 -0
- code/RL_model/verl/verl_train/tests/utils/test_mlflow_key_sanitization.py +64 -0
- code/RL_model/verl/verl_train/tests/utils/test_model_on_cpu.py +52 -0
- code/RL_model/verl/verl_train/tests/utils/test_nvtx_profile.py +168 -0
- code/RL_model/verl/verl_train/tests/utils/test_rollout_skip_on_cpu.py +142 -0
- code/RL_model/verl/verl_train/tests/utils/test_rollout_trace_on_cpu.py +246 -0
- code/RL_model/verl/verl_train/tests/utils/test_seqlen_balancing.py +278 -0
- code/RL_model/verl/verl_train/tests/utils/test_shared_memory.py +260 -0
- code/RL_model/verl/verl_train/tests/utils/test_special_linear_cross_entropy_tp.py +514 -0
- code/RL_model/verl/verl_train/tests/utils/test_special_mstx_profile.py +274 -0
- code/RL_model/verl/verl_train/tests/utils/test_temp_env_on_cpu.py +143 -0
- code/RL_model/verl/verl_train/tests/utils/test_timeout_decorator_cpu.py +238 -0
- code/RL_model/verl/verl_train/tests/utils/test_torch_functional.py +152 -0
code/RL_model/verl/verl_train/tests/single_controller/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
code/RL_model/verl/verl_train/tests/single_controller/test_auto_padding_on_cpu.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
import ray
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
from verl import DataProto
|
| 20 |
+
from verl.protocol import DataProtoConfig
|
| 21 |
+
from verl.single_controller.base import Worker
|
| 22 |
+
from verl.single_controller.base.decorator import Dispatch, register
|
| 23 |
+
from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
|
| 24 |
+
|
| 25 |
+
# or set env var VERL_AUTO_PADDING = "1" / "true"
|
| 26 |
+
DataProtoConfig.auto_padding = True
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@ray.remote
class Actor(Worker):
    """Minimal Ray worker used to exercise DP_COMPUTE_PROTO auto padding."""

    def __init__(self) -> None:
        super().__init__()

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def add(self, data: DataProto):
        # Add this worker's data-parallel rank to its shard in place; the
        # dispatch layer splits the DataProto across workers and concatenates
        # (and unpads) the returned shards.
        data.batch["a"] += self.rank
        return data
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _check_dispatch_roundtrip(actor_wg, make_data, expected_len):
    """Send a freshly built DataProto through the worker group both
    positionally and by keyword, asserting the auto-padded output is unpadded
    back to the original length.

    Args:
        actor_wg: worker group exposing the DP_COMPUTE_PROTO ``add`` method.
        make_data: zero-arg factory producing a fresh DataProto per call.
        expected_len: length the returned DataProto must have.
    """
    output = actor_wg.add(make_data())
    print(output.batch["a"])
    assert len(output) == expected_len, "Failed in args split and padding."

    output = actor_wg.add(data=make_data())
    print(output.batch["a"])
    assert len(output) == expected_len, "Failed in kwargs split and padding."


def test_auto_padding():
    """End-to-end check of DataProto auto padding: local padding/chunking
    arithmetic first, then round-trips through a DP_COMPUTE_PROTO worker
    group, and finally per-DataProto ``auto_padding`` overriding the global
    config."""
    ray.init(num_cpus=100)

    chunk_size = 4
    actor_cls = RayClassWithInitArgs(cls=Actor)
    resource_pool = RayResourcePool(process_on_nodes=[chunk_size], use_gpu=False)
    actor_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=actor_cls)

    # test locally first, without going through Ray dispatch
    for test_size in range(4, 20):
        local_data = DataProto.from_dict({"a": torch.zeros(test_size)}, {"na": np.zeros(test_size, dtype=object)})
        padding_size = (chunk_size - (test_size % chunk_size)) if (test_size % chunk_size > 0) else 0
        local_data.padding(padding_size)
        # Bug fix: the original assertion compared len(local_data) against
        # itself (len == len + len % chunk_size), which only checked
        # divisibility as a side effect. Compare against the true expected
        # padded length instead.
        assert len(local_data) == test_size + padding_size, (
            f"expecting padded length to be {test_size + padding_size}, but got {len(local_data)}"
        )
        chunked = local_data.chunk(chunk_size)
        assert len(chunked) == chunk_size, f"during test_size = {test_size}, expecting {chunk_size}, got {chunked}"
        for dp in chunked:
            assert len(dp) == test_size // chunk_size + bool(test_size % chunk_size), (
                f"test size = {test_size}, expecting dp to be length of "
                f"{test_size // chunk_size + bool(test_size % chunk_size)}, but got {len(dp)}: {dp} {chunked}"
            )

    # test with RayWorkerGroup method decorated as dispatch_mode=Dispatch.DP_COMPUTE_PROTO
    # sizes cover: needs padding (10), single element (1), already aligned (8)
    for size in (10, 1, 8):
        _check_dispatch_roundtrip(
            actor_wg,
            lambda size=size: DataProto.from_dict(
                {"a": torch.zeros(size)}, {"na": np.array([str(i) for i in range(size)], dtype=object)}
            ),
            size,
        )

    # test data proto specific config: per-DataProto auto_padding must work
    # even when the global default is disabled
    DataProtoConfig.auto_padding = False

    _check_dispatch_roundtrip(
        actor_wg,
        lambda: DataProto.from_dict(
            {"a": torch.zeros(10)}, {"na": np.array([str(i) for i in range(10)], dtype=object)}, auto_padding=True
        ),
        10,
    )
    _check_dispatch_roundtrip(
        actor_wg,
        lambda: DataProto.from_single_dict(
            {"a": torch.zeros(1), "na": np.array([str(i) for i in range(1)], dtype=object)}, auto_padding=True
        ),
        1,
    )
    # size 8 is already a multiple of the world size, so no padding is needed
    # and the global auto_padding=False default is fine
    _check_dispatch_roundtrip(
        actor_wg,
        lambda: DataProto.from_single_dict(
            {"a": torch.zeros(8), "na": np.array([str(i) for i in range(8)], dtype=object)}
        ),
        8,
    )

    ray.shutdown()


if __name__ == "__main__":
    test_auto_padding()
|
code/RL_model/verl/verl_train/tests/single_controller/test_colocated_workers.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import ray
|
| 16 |
+
|
| 17 |
+
from verl import DataProto
|
| 18 |
+
from verl.single_controller.base import Worker
|
| 19 |
+
from verl.single_controller.base.decorator import Dispatch, register
|
| 20 |
+
from verl.single_controller.ray.base import (
|
| 21 |
+
RayClassWithInitArgs,
|
| 22 |
+
RayResourcePool,
|
| 23 |
+
RayWorkerGroup,
|
| 24 |
+
create_colocated_worker_cls,
|
| 25 |
+
)
|
| 26 |
+
from verl.utils.device import get_device_name
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@ray.remote
class Actor(Worker):
    """Ray worker whose ``add`` method adds its DP rank to the "a" tensor."""

    def __init__(self) -> None:
        super().__init__()

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def add(self, data: DataProto):
        # In-place add of this worker's rank to its shard of the batch.
        data.batch["a"] += self.rank
        return data
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@ray.remote
class Critic(Worker):
    """Ray worker whose ``sub`` method subtracts a configured constant."""

    def __init__(self, config) -> None:
        super().__init__()
        # assumes config is a mapping with an integer under key "b" — the
        # test constructs it as {"b": 10}
        self.config = config

    # NOTE(review): declared `async` here while the fused-test Critic uses a
    # plain `def` — presumably to also exercise async method dispatch; confirm
    # this asymmetry is intentional.
    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    async def sub(self, data: DataProto):
        data.batch["a"] -= self.config["b"]
        return data
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def test_colocated_workers():
    """Standalone worker groups and colocated worker groups built from the
    same classes on the same resource pool must produce identical outputs."""
    ray.init()

    import torch

    batch = DataProto.from_dict({"a": torch.zeros(10)})

    # Stand-alone worker groups sharing one resource pool act as the reference.
    actor_init = RayClassWithInitArgs(cls=Actor)
    critic_init = RayClassWithInitArgs(cls=Critic, config={"b": 10})
    pool = RayResourcePool(process_on_nodes=[2])
    device = get_device_name()

    actor_group = RayWorkerGroup(resource_pool=pool, ray_cls_with_init=actor_init, device_name=device)
    critic_group = RayWorkerGroup(resource_pool=pool, ray_cls_with_init=critic_init, device_name=device)

    reference_add = actor_group.add(batch)
    reference_sub = critic_group.sub(batch)

    # Build the colocated variant from the same classes and spawn per-role views.
    role_to_cls = {"actor": actor_init, "critic": critic_init}
    colocated_cls = create_colocated_worker_cls(role_to_cls)
    colocated_group = RayWorkerGroup(resource_pool=pool, ray_cls_with_init=colocated_cls, device_name=device)
    per_role = colocated_group.spawn(prefix_set=role_to_cls.keys())

    colocated_add = per_role["actor"].add(batch)
    colocated_sub = per_role["critic"].sub(batch)

    # Exact equality: colocating must not change the numerics at all.
    torch.testing.assert_close(reference_add.batch, colocated_add.batch, atol=0, rtol=0)
    torch.testing.assert_close(reference_sub.batch, colocated_sub.batch, atol=0, rtol=0)

    ray.shutdown()
|
code/RL_model/verl/verl_train/tests/single_controller/test_colocated_workers_fused.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import ray
|
| 16 |
+
|
| 17 |
+
from verl import DataProto
|
| 18 |
+
from verl.single_controller.base import Worker
|
| 19 |
+
from verl.single_controller.base.decorator import Dispatch, register
|
| 20 |
+
from verl.single_controller.ray.base import (
|
| 21 |
+
RayClassWithInitArgs,
|
| 22 |
+
RayResourcePool,
|
| 23 |
+
RayWorkerGroup,
|
| 24 |
+
create_colocated_worker_cls_fused,
|
| 25 |
+
)
|
| 26 |
+
from verl.utils.device import get_device_name
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@ray.remote
class Actor(Worker):
    """Ray worker whose ``add`` method adds its DP rank to the "a" tensor."""

    def __init__(self) -> None:
        super().__init__()

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def add(self, data: DataProto):
        # In-place add of this worker's rank to its shard of the batch.
        data.batch["a"] += self.rank
        return data
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@ray.remote
class Critic(Worker):
    """Ray worker whose ``sub`` method subtracts a configured constant."""

    def __init__(self, config) -> None:
        super().__init__()
        # assumes config is a mapping with an integer under key "b" — the
        # test constructs it as {"b": 10}
        self.config = config

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def sub(self, data: DataProto):
        data.batch["a"] -= self.config["b"]
        return data
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def test_colocated_workers_fused():
    """Fused colocated worker groups must match the outputs of stand-alone
    worker groups built from the same classes on the same resource pool."""
    ray.init()

    import torch

    batch = DataProto.from_dict({"a": torch.zeros(10)})

    # Stand-alone worker groups sharing one resource pool act as the reference.
    actor_init = RayClassWithInitArgs(cls=Actor)
    critic_init = RayClassWithInitArgs(cls=Critic, config={"b": 10})
    pool = RayResourcePool(process_on_nodes=[2])
    device = get_device_name()

    actor_group = RayWorkerGroup(resource_pool=pool, ray_cls_with_init=actor_init, device_name=device)
    critic_group = RayWorkerGroup(resource_pool=pool, ray_cls_with_init=critic_init, device_name=device)

    reference_add = actor_group.add(batch)
    reference_sub = critic_group.sub(batch)

    # Build the fused colocated variant and spawn per-role views.
    role_to_cls = {"actor": actor_init, "critic": critic_init}
    fused_cls = create_colocated_worker_cls_fused(role_to_cls)
    fused_group = RayWorkerGroup(resource_pool=pool, ray_cls_with_init=fused_cls, device_name=device)
    per_role = fused_group.spawn(prefix_set=role_to_cls.keys())

    fused_add = per_role["actor"].add(batch)
    fused_sub = per_role["critic"].sub(batch)

    # Exact equality: fusing must not change the numerics at all.
    torch.testing.assert_close(reference_add.batch, fused_add.batch, atol=0, rtol=0)
    torch.testing.assert_close(reference_sub.batch, fused_sub.batch, atol=0, rtol=0)

    ray.shutdown()
|
code/RL_model/verl/verl_train/tests/single_controller/test_data_transfer.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""
|
| 15 |
+
In this test, we instantiate a data parallel worker with 8 GPUs
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import ray
|
| 19 |
+
import tensordict
|
| 20 |
+
import torch
|
| 21 |
+
from codetiming import Timer
|
| 22 |
+
from packaging import version
|
| 23 |
+
from torch import distributed as dist
|
| 24 |
+
|
| 25 |
+
from verl import DataProto
|
| 26 |
+
from verl.single_controller.base import Worker
|
| 27 |
+
from verl.single_controller.base.decorator import Dispatch, register
|
| 28 |
+
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
|
| 29 |
+
from verl.utils.device import get_device_name
|
| 30 |
+
from verl.utils.ray_utils import parallel_put
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@ray.remote
class DummyWorker(Worker):
    """Worker that increments every tensor in the batch by one, used to
    measure data-transfer cost through the single controller."""

    def __init__(self):
        super().__init__()
        # Join the default torch.distributed process group; presumably the
        # Worker base class exports the required env vars — confirm.
        dist.init_process_group()

    @register(dispatch_mode=Dispatch.DP_COMPUTE, blocking=False)
    def do_nothing(self, data):
        # Despite the name, this mutates each tensor in place (+1) so the
        # driver can verify the round trip actually reached the worker.
        for key in data.batch.keys():
            data.batch[key] += 1
        # tensordict >= 0.5.0: consolidate before returning for efficient
        # serialization back to the driver.
        if version.parse(tensordict.__version__) >= version.parse("0.5.0"):
            data.batch = data.batch.consolidate()
        return data
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def test_data_transfer():
    """Benchmark serialization and transfer of a large DataProto through a
    worker group, then verify each worker incremented every tensor by one."""
    ray.init()

    # One 8-process pool, one DummyWorker per process.
    pool = RayResourcePool([8])
    worker_cls = RayClassWithInitArgs(cls=DummyWorker)
    wg = RayWorkerGroup(pool, worker_cls, device_name=get_device_name())

    # Realistic dataset dimensions.
    n_rows = 4096
    n_cols = 32768

    tensors = {str(idx): torch.randint(0, 10000, (n_rows, n_cols)) for idx in range(2)}
    data = DataProto.from_dict(tensors=tensors)
    print(data)

    # Manually shard across the workers instead of relying on a dispatch mode.
    shards = data.chunk(wg.world_size)

    # consolidate is necessary on tensordict >= 0.5.0 for fast serialization
    if version.parse(tensordict.__version__) >= version.parse("0.5.0"):
        for shard in shards:
            shard.batch = shard.batch.consolidate()

    with Timer(name="ray.pickle", initial_text=True):
        for shard in shards:
            ray.cloudpickle.pickle.dumps(shard)

    with Timer(name="raw.pickle", initial_text=True):
        import pickle

        for shard in shards:
            pickle.dumps(shard)

    # Put the shards into the object store up front (can take ~40s at this size).
    with Timer(name="put", initial_text=True):
        shard_refs = parallel_put(shards)

    with Timer(name="launch", initial_text=True):
        result_refs = wg.do_nothing(shard_refs)

    with Timer(name="get", initial_text=True):
        # takes around 40 seconds
        results = ray.get(result_refs)

    for sent, received in zip(shards, results, strict=True):
        for key in sent.batch.keys():
            assert torch.all(torch.eq(sent.batch[key] + 1, received.batch[key])), (
                sent.batch[key],
                received.batch[key],
                key,
            )

    ray.shutdown()
|
code/RL_model/verl/verl_train/tests/single_controller/test_decorator_on_cpu.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import asyncio
|
| 16 |
+
import time
|
| 17 |
+
|
| 18 |
+
import pytest
|
| 19 |
+
import ray
|
| 20 |
+
import torch
|
| 21 |
+
from tensordict import TensorDict
|
| 22 |
+
|
| 23 |
+
from verl.protocol import DataProto, DataProtoFuture
|
| 24 |
+
from verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register
|
| 25 |
+
from verl.single_controller.base.worker import Worker
|
| 26 |
+
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
|
| 27 |
+
from verl.utils import tensordict_utils as tu
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# Pytest fixture for Ray setup/teardown
|
| 31 |
+
@pytest.fixture
def ray_init_shutdown():
    """Start a CPU-only local Ray instance for a test, shut it down after."""
    ray.init(num_cpus=100)
    yield
    ray.shutdown()
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# Define a simple worker for testing
|
| 39 |
+
@ray.remote
class DecoratorTestWorker(Worker):
    """Worker exercising the @register decorator: sync and async
    DP_COMPUTE_PROTO methods plus an nd-compute TensorDict method."""

    def __init__(self, initial_value=0):
        super().__init__()
        # Constant added to every input by the compute methods below.
        self.value = initial_value
        # Simulate some setup if needed
        time.sleep(0.1)  # Ensure worker init completes

        # Register this worker into the "train" device mesh so that
        # make_nd_compute_dataproto_dispatch_fn can route data to it.
        self._register_dispatch_collect_info(mesh_name="train", dp_rank=self.rank, is_collect=True)

    # Test method for synchronous DP compute (default behavior)
    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def dp_compute(self, data: DataProto) -> DataProto:
        time.sleep(0.1)  # Simulate work
        # Build the rank as a tensor matching the input's device/dtype so the
        # addition below broadcasts cleanly.
        rank_value = torch.tensor(self.rank, device=data.batch["input"].device, dtype=data.batch["input"].dtype)
        data.batch["output"] = data.batch["input"] + self.value + rank_value
        return data

    # Test async def method with DP compute (default behavior)
    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO, blocking=False)
    async def async_dp_compute(self, data: DataProto) -> DataProto:
        # Simulate async work
        await asyncio.sleep(0.1)  # Simulate async work
        rank_value = torch.tensor(self.rank, device=data.batch["input"].device, dtype=data.batch["input"].dtype)
        data.batch["output_async"] = data.batch["input"] * 2 + self.value + rank_value
        return data

    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="train"), blocking=False)
    def dp_compute_td(self, data: TensorDict) -> TensorDict:
        # note that we have to call contiguous so that we can modify data in place
        data = tu.contiguous(data)
        rank_value = torch.tensor(self.rank, device=data["input"].device, dtype=data["input"].dtype)
        data["output"] = data["input"] + self.value + rank_value
        # NOTE(review): force the ragged dimension index of the nested/jagged
        # position_ids tensor to 2 — presumably to undo a change made during
        # transfer; confirm against the dispatch implementation.
        position_ids = data.pop("position_ids")
        position_ids._ragged_idx = 2

        # Each row is expected to be an arange whose length depends on rank
        # and row index (matching how the driver test constructs the data).
        for i, position_id in enumerate(position_ids.unbind(dim=0)):
            assert (position_id == torch.arange(4 + rank_value * 2 + i).expand(position_id.shape)).all()

        return data
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
# Test function for synchronous DP compute
|
| 82 |
+
def test_decorator_dp_compute(ray_init_shutdown):
    """Synchronous DP_COMPUTE_PROTO dispatch: the input batch is split across
    two workers, each adds initial_value + its rank, and the shards are
    concatenated back in order."""
    world_size = 2
    pool = RayResourcePool([world_size], use_gpu=False, max_colocate_count=1)  # CPU-only keeps the test simple
    init_args = RayClassWithInitArgs(cls=DecoratorTestWorker, initial_value=10)
    group = RayWorkerGroup(pool, init_args, name_prefix=f"decorator_test_sync_dp_{int(time.time())}")

    # Four input rows -> two rows per worker.
    batch = DataProto(batch=TensorDict({"input": torch.arange(4, dtype=torch.float32)}, batch_size=[4]))

    result = group.dp_compute(batch)

    assert isinstance(result, DataProto), "Expected DataProto result"
    assert "output" in result.batch.keys()
    assert len(result) == len(batch), "Output length should match input length"

    # Worker 0 handles rows [0, 1] and adds initial_value(10) + rank(0);
    # worker 1 handles rows [2, 3] and adds initial_value(10) + rank(1).
    expected = torch.cat(
        [
            torch.tensor([0, 1], dtype=torch.float32) + 10 + 0,
            torch.tensor([2, 3], dtype=torch.float32) + 10 + 1,
        ]
    )

    torch.testing.assert_close(result.batch["output"], expected, msg="Sync DP compute output data mismatch")
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# Test function for async def method with DP compute
def test_decorator_async_function(ray_init_shutdown):
    """Exercise an `async def` decorated method using DP_COMPUTE_PROTO.

    The call must return a DataProtoFuture immediately; the gathered result
    is only materialized (and then validated) by .get().
    """
    world_size = 2
    pool = RayResourcePool([world_size], use_gpu=False, max_colocate_count=1)
    worker_cls = RayClassWithInitArgs(cls=DecoratorTestWorker, initial_value=5)
    group = RayWorkerGroup(
        pool, worker_cls, name_prefix=f"decorator_test_async_dp_{int(time.time())}"
    )

    # Batch of 4 so each of the 2 workers receives a shard of 2 rows.
    inputs = torch.arange(4, dtype=torch.float32)
    proto_in = DataProto(batch=TensorDict({"input": inputs}, batch_size=[4]))

    # Async dispatch returns a future rather than a concrete DataProto.
    pending: DataProtoFuture = group.async_dp_compute(proto_in)
    assert isinstance(pending, DataProtoFuture), "Expected DataProtoFuture for async def call"

    # Materialize the result (blocks until every worker finishes).
    result_data = pending.get()

    assert isinstance(result_data, DataProto)
    assert "output_async" in result_data.batch.keys()
    assert len(result_data) == len(proto_in), "Output length should match input length"

    # Worker r computes shard * 2 + initial_value(5) + rank(r).
    expected = torch.cat(
        [
            (torch.tensor([0, 1], dtype=torch.float32) * 2) + 5 + 0,
            (torch.tensor([2, 3], dtype=torch.float32) * 2) + 5 + 1,
        ]
    )
    torch.testing.assert_close(
        result_data.batch["output_async"], expected, msg="Async DP compute output data mismatch"
    )
| 161 |
+
def test_decorator_dp_compute_td(ray_init_shutdown):
    """Tests a DP_COMPUTE method that consumes and returns a TensorDict.

    The TensorDict carries a regular tensor plus a jagged nested tensor to
    make sure non-rectangular fields survive the dispatch/collect round trip.
    """
    num_workers = 2
    resource_pool = RayResourcePool([num_workers], use_gpu=False, max_colocate_count=1)  # Use CPU for simplicity
    cls_with_args = RayClassWithInitArgs(cls=DecoratorTestWorker, initial_value=10)
    worker_group = RayWorkerGroup(
        resource_pool, cls_with_args, name_prefix=f"decorator_test_sync_dp_{int(time.time())}"
    )

    # Prepare input data (size 4, for 2 workers); position_ids is jagged on dim 1.
    input_tensor = torch.arange(4, dtype=torch.float32)
    position_ids = torch.nested.as_nested_tensor(
        [
            torch.arange(4).expand(4, 4).contiguous(),
            torch.arange(5).expand(4, 5).contiguous(),
            torch.arange(6).expand(4, 6).contiguous(),
            torch.arange(7).expand(4, 7).contiguous(),
        ],
        layout=torch.jagged,
    )
    data = TensorDict({"input": input_tensor, "position_ids": position_ids}, batch_size=[4])

    # Call the decorated method; the TensorDict path returns a future.
    output = worker_group.dp_compute_td(data)

    output = output.get()

    # Assert the result correctness.
    # Fix: the message previously said "Expected DataProto result" although
    # this method returns a TensorDict.
    assert isinstance(output, TensorDict), "Expected TensorDict result"
    assert "output" in output.keys()
    assert len(output) == len(data), "Output length should match input length"

    # Expected output calculation for DP_COMPUTE_PROTO with 2 workers:
    # Worker 0 gets data[0:2], Worker 1 gets data[2:4].
    # Worker 0 adds initial_value(10) + rank(0) = 10
    # Worker 1 adds initial_value(10) + rank(1) = 11
    expected_output_part1 = torch.tensor([0, 1], dtype=torch.float32) + 10 + 0
    expected_output_part2 = torch.tensor([2, 3], dtype=torch.float32) + 10 + 1
    expected_output = torch.cat([expected_output_part1, expected_output_part2])

    torch.testing.assert_close(output["output"], expected_output, msg="Sync DP compute output data mismatch")
|
code/RL_model/verl/verl_train/tests/single_controller/test_device_mesh_register.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
import ray
|
| 18 |
+
import torch
|
| 19 |
+
from tensordict import TensorDict
|
| 20 |
+
|
| 21 |
+
import verl.utils.tensordict_utils as tu
|
| 22 |
+
from verl import DataProto
|
| 23 |
+
from verl.single_controller.base import Worker
|
| 24 |
+
from verl.single_controller.base.decorator import make_nd_compute_dataproto_dispatch_fn, register
|
| 25 |
+
from verl.utils.device import get_device_name, get_nccl_backend
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@ray.remote
class TestActor(Worker):
    """Worker registering two dispatch/collect layouts over device meshes.

    An "infer" mesh (dp=2, tp=4) and a "train" mesh (pp=2, dp=2, tp=2) span
    the same 8 ranks; each registered method mutates its payload so the
    driver can verify which shard every rank processed.
    """

    def __init__(self):
        super().__init__()

        import torch.distributed

        torch.distributed.init_process_group(backend=get_nccl_backend())
        self.infer_device_mesh = torch.distributed.device_mesh.init_device_mesh(
            device_type=get_device_name(), mesh_shape=[2, 4], mesh_dim_names=["dp", "tp"]
        )
        self.train_device_mesh = torch.distributed.device_mesh.init_device_mesh(
            device_type=get_device_name(), mesh_shape=[2, 2, 2], mesh_dim_names=["pp", "dp", "tp"]
        )

        # Inference: only tp rank 0 of each dp group collects output.
        self._register_dispatch_collect_info(
            "infer",
            dp_rank=self.infer_device_mesh["dp"].get_local_rank(),
            is_collect=self.infer_device_mesh["tp"].get_local_rank() == 0,
        )
        # Training: collect on tp rank 0 of pp rank 1 (the last stage).
        self._register_dispatch_collect_info(
            "train",
            dp_rank=self.train_device_mesh["dp"].get_local_rank(),
            is_collect=self.train_device_mesh["tp"].get_local_rank() == 0
            and self.train_device_mesh["pp"].get_local_rank() == 1,
        )

    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="infer"))
    def generate_data_proto(self, data: DataProto):
        """Shift key "a" by (tp_rank + 1) * dp_rank on the infer mesh."""
        mesh = self.infer_device_mesh
        data.batch["a"] += (mesh["tp"].get_local_rank() + 1) * mesh["dp"].get_local_rank()
        return data

    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="infer"))
    def generate_tensordict(self, data: TensorDict):
        """TensorDict twin of generate_data_proto: same shift, same mesh."""
        mesh = self.infer_device_mesh
        data["a"] += (mesh["tp"].get_local_rank() + 1) * mesh["dp"].get_local_rank()
        return data

    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="train"))
    def train_data_proto(self, data: DataProto):
        """Shift key "a" by (tp+1)*(dp+2)*(pp+3) on the train mesh."""
        mesh = self.train_device_mesh
        shift = (
            (mesh["tp"].get_local_rank() + 1)
            * (mesh["dp"].get_local_rank() + 2)
            * (mesh["pp"].get_local_rank() + 3)
        )
        data.batch["a"] += shift
        # tp rank 0, pp rank 1, dp rank 0, output data added: 8 + 3 = 11
        # tp rank 0, pp rank 1, dp rank 1, output data added: 12 + 4 = 16
        return data

    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="train"))
    def train_tensordict(self, data: TensorDict):
        """TensorDict twin of train_data_proto: same shift, same mesh."""
        mesh = self.train_device_mesh
        shift = (
            (mesh["tp"].get_local_rank() + 1)
            * (mesh["dp"].get_local_rank() + 2)
            * (mesh["pp"].get_local_rank() + 3)
        )
        data["a"] += shift
        # tp rank 0, pp rank 1, dp rank 0, output data added: 8 + 3 = 11
        # tp rank 0, pp rank 1, dp rank 1, output data added: 12 + 4 = 16
        return data

    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="infer"))
    def generate_nested_tensor(self, data: TensorDict):
        """Shift jagged "input_ids" by tp_rank + dp_rank on the infer mesh."""
        mesh = self.infer_device_mesh
        # 16 driver-side rows split over dp=2 yields 8 rows per shard.
        assert data.shape[0] == 8
        data["input_ids"] += mesh["tp"].get_local_rank() + mesh["dp"].get_local_rank()

        print(data)
        return data
| 99 |
+
|
| 100 |
+
def test_dist_global_info_wg():
    """End-to-end check of mesh-registered dispatch on an 8-worker group.

    - "infer" mesh (tp=4, dp=2) must yield dispatch info [0,0,0,0,1,1,1,1]
    - "train" mesh (pp=2, dp=2, tp=2) must yield [0,0,1,1,0,0,1,1]
    Verifies DataProto, TensorDict, and jagged nested-tensor payloads.
    """
    from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup

    ray.init()

    ray_cls = RayClassWithInitArgs(TestActor)
    pool = RayResourcePool(process_on_nodes=[8])
    wg = RayWorkerGroup(resource_pool=pool, ray_cls_with_init=ray_cls, device_name=get_device_name())

    infer_in = DataProto.from_single_dict(data={"a": torch.tensor([1, 2])})
    infer_out = wg.generate_data_proto(infer_in)

    assert wg._dispatch_info["infer"] == [0, 0, 0, 0, 1, 1, 1, 1]

    # Collected from tp rank 0: row i is shifted by dp_rank == i.
    assert torch.all(torch.eq(infer_out.batch["a"], torch.tensor([1, 3])))

    infer_td_out = wg.generate_tensordict(infer_in.to_tensordict())
    assert torch.all(torch.eq(infer_td_out["a"], torch.tensor([1, 3])))

    train_in = DataProto.from_single_dict(data={"a": torch.tensor([3, 4])})
    train_out = wg.train_data_proto(train_in)

    assert wg._dispatch_info["train"] == [0, 0, 1, 1, 0, 0, 1, 1]

    # Expected shifts documented on TestActor.train_data_proto: 11 and 16.
    assert torch.all(torch.eq(train_out.batch["a"], torch.tensor([11, 16])))

    train_td_out = wg.train_tensordict(train_in.to_tensordict())
    assert torch.all(torch.eq(train_td_out["a"], torch.tensor([11, 16])))

    # Jagged nested-tensor payload: 16 variable-length rows, 8 per dp shard.
    input_ids = [
        torch.randint(low=0, high=128, size=(np.random.randint(low=1, high=10, dtype=np.int64),)) for _ in range(16)
    ]
    input_ids = torch.nested.as_nested_tensor(input_ids, layout=torch.jagged)
    data = tu.get_tensordict(tensor_dict={"input_ids": input_ids})
    output = wg.generate_nested_tensor(data)

    halves = list(input_ids.chunk(2))

    print(halves)

    # Mirror the per-rank mutation collected from tp rank 0: shard i is
    # shifted by its dp rank, i.e. 0 and 1 respectively.
    halves[0] += 0
    halves[1] += 1

    expected = tu.concat_nested_tensors(halves)

    assert torch.all(torch.eq(output["input_ids"].values(), expected.values()))

    ray.shutdown()
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
if __name__ == "__main__":
|
| 158 |
+
test_dist_global_info_wg()
|
code/RL_model/verl/verl_train/tests/single_controller/test_driverfunc_to_worker.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
|
| 17 |
+
import ray
|
| 18 |
+
import torch
|
| 19 |
+
from tensordict import TensorDict
|
| 20 |
+
|
| 21 |
+
from verl import DataProto
|
| 22 |
+
from verl.single_controller.base.worker import Worker
|
| 23 |
+
from verl.single_controller.ray import RayWorkerGroup
|
| 24 |
+
from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool
|
| 25 |
+
from verl.utils.device import get_device_name
|
| 26 |
+
|
| 27 |
+
os.environ["RAY_DEDUP_LOGS"] = "0"
|
| 28 |
+
os.environ["NCCL_DEBUG"] = "WARN"
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@ray.remote
|
| 32 |
+
class ModelActor(Worker):
|
| 33 |
+
def __init__(self):
|
| 34 |
+
pass
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class HackSelf:
|
| 38 |
+
def __init__(self):
|
| 39 |
+
pass
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def get_aux_metrics(self, test_proto):
    """Count decoded ids per sequence and return them alongside the input.

    Driver-defined function intended to run either locally or on remote
    workers via ``execute_with_func_generator`` (hence the unused ``self``).
    """
    sequence_ids = test_proto.batch["sequence_ids"]
    # One entry per row: length of the row's id list.
    decode_count = [len(sequence_ids[row].tolist()) for row in range(sequence_ids.size(0))]
    return DataProto(
        batch=TensorDict(
            {"sequence_ids": sequence_ids, "decode_count": torch.tensor(decode_count)},
            batch_size=sequence_ids.size(0),
        )
    )
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def test():
    """Run a driver-defined function on remote workers and compare with local execution."""
    # construct model
    ray.init()

    # Two workers, one GPU each.
    resource_pool = RayResourcePool([2], use_gpu=True, name_prefix="a")

    cls_with_args = RayClassWithInitArgs(cls=ModelActor)
    shard_wg = RayWorkerGroup(resource_pool, cls_with_args, device_name=get_device_name())

    test_bs = 8
    test_proto = DataProto(
        TensorDict(
            {"sequence_ids": torch.ones([test_bs, 2048], dtype=torch.int64)},
            batch_size=test_bs,
        ),
        meta_info={"query_length": 1536},
    )

    # Path 1: shard the batch among ranks and run the function remotely.
    ret_proto1 = shard_wg.execute_with_func_generator(get_aux_metrics, test_proto)

    # Path 2: run the same function locally with a stand-in "self".
    ret_proto2 = get_aux_metrics(HackSelf(), test_proto)

    # Both paths must agree on the computed per-row counts.
    torch.testing.assert_close(ret_proto1.batch["decode_count"], ret_proto2.batch["decode_count"])

    ray.shutdown()
|
code/RL_model/verl/verl_train/tests/single_controller/test_fused_workers_on_cpu.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import ray
|
| 16 |
+
|
| 17 |
+
from verl.single_controller.base import Worker
|
| 18 |
+
from verl.single_controller.base.decorator import Dispatch, register
|
| 19 |
+
from verl.single_controller.ray.base import (
|
| 20 |
+
RayClassWithInitArgs,
|
| 21 |
+
RayResourcePool,
|
| 22 |
+
RayWorkerGroup,
|
| 23 |
+
create_colocated_worker_raw_cls,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@ray.remote
|
| 28 |
+
class Actor(Worker):
|
| 29 |
+
def __init__(self) -> None:
|
| 30 |
+
super().__init__()
|
| 31 |
+
|
| 32 |
+
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
|
| 33 |
+
def add(self, x):
|
| 34 |
+
x += self.rank
|
| 35 |
+
return x
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@ray.remote
|
| 39 |
+
class Critic(Worker):
|
| 40 |
+
def __init__(self, val) -> None:
|
| 41 |
+
super().__init__()
|
| 42 |
+
self.val = val
|
| 43 |
+
|
| 44 |
+
@register(dispatch_mode=Dispatch.ALL_TO_ALL)
|
| 45 |
+
def sub(self, x):
|
| 46 |
+
x -= self.val
|
| 47 |
+
return x
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
actor_cls = RayClassWithInitArgs(cls=Actor)
|
| 51 |
+
critic_cls = RayClassWithInitArgs(cls=Critic, val=10)
|
| 52 |
+
cls_dict = {"actor": actor_cls, "critic": critic_cls}
|
| 53 |
+
FusedBaseClass = create_colocated_worker_raw_cls(cls_dict)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@ray.remote
|
| 57 |
+
class HybridWorker(FusedBaseClass):
|
| 58 |
+
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
|
| 59 |
+
def foo(self, x):
|
| 60 |
+
return self.critic.sub(self.actor.add(x))
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def test_fused_workers():
    """Colocate Actor and Critic inside one fused worker and check both call paths agree."""
    ray.init(num_cpus=100)

    # Two CPU processes on a single node share one resource pool.
    resource_pool = RayResourcePool(process_on_nodes=[2], use_gpu=False)

    # The fused class must be flagged so the worker group wires up sub-workers.
    hybrid_cls_with_init = RayClassWithInitArgs(cls=HybridWorker)
    hybrid_cls_with_init.fused_worker_used = True

    fused_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=hybrid_cls_with_init)
    fused_wg.fuse(cls_dict.keys())

    # Path 1: drive the sub-workers step by step from the driver.
    x = fused_wg.actor.add(0.1)
    print(x)
    y = fused_wg.critic.sub(x)
    print(y)
    # Path 2: a single fused method chains the same two calls on the worker.
    z = fused_wg.foo(0.1)
    print(z)
    for stepwise, fused in zip(y, z, strict=True):
        assert stepwise == fused

    ray.shutdown()
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
if __name__ == "__main__":
|
| 90 |
+
test_fused_workers()
|
code/RL_model/verl/verl_train/tests/single_controller/test_high_level_scheduling_api.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import gc
|
| 15 |
+
import time
|
| 16 |
+
|
| 17 |
+
import ray
|
| 18 |
+
|
| 19 |
+
from verl.single_controller.base.worker import Worker
|
| 20 |
+
from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, merge_resource_pool
|
| 21 |
+
from verl.utils.device import get_device_name
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@ray.remote
|
| 25 |
+
class TestActor(Worker):
|
| 26 |
+
# TODO: pass *args and **kwargs is bug prone and not very convincing
|
| 27 |
+
def __init__(self, cuda_visible_devices=None) -> None:
|
| 28 |
+
super().__init__(cuda_visible_devices)
|
| 29 |
+
|
| 30 |
+
def get_node_id(self):
|
| 31 |
+
return ray.get_runtime_context().get_node_id()
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def test():
    """Exercise the high-level scheduling API on a single 8-GPU node.

    Phase 1 (no partition): four worker groups colocated on one 8-GPU pool;
    each group must see all 8 devices. Phase 2 (multi partition): two 4-GPU
    pools are merged; groups on the merged pool see all 8 devices while
    groups on the sub-pools see only their own half.
    """
    ray.init()

    # test single-node-no-partition
    print("test single-node-no-partition")
    resource_pool = RayResourcePool([8], use_gpu=True)

    class_with_args = RayClassWithInitArgs(cls=TestActor)

    print("create actor worker group")
    actor_wg = RayWorkerGroup(
        resource_pool, class_with_args, name_prefix="high_level_api_actor", device_name=get_device_name()
    )
    print("create critic worker group")
    # Fix: prefix previously read "hight_level_api_critic" (typo), which was
    # inconsistent with the other "high_level_api_*" prefixes.
    critic_wg = RayWorkerGroup(
        resource_pool, class_with_args, name_prefix="high_level_api_critic", device_name=get_device_name()
    )
    print("create rm worker group")
    rm_wg = RayWorkerGroup(
        resource_pool, class_with_args, name_prefix="high_level_api_rm", device_name=get_device_name()
    )
    print("create ref worker group")
    ref_wg = RayWorkerGroup(
        resource_pool, class_with_args, name_prefix="high_level_api_ref", device_name=get_device_name()
    )

    # Every colocated group on the full pool must see GPUs 0-7.
    assert actor_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]
    assert critic_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]
    assert rm_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]
    assert ref_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]

    del actor_wg
    del critic_wg
    del rm_wg
    del ref_wg
    gc.collect()  # make sure ray actors are deleted

    [ray.util.remove_placement_group(pg) for pg in resource_pool.get_placement_groups()]
    print("wait 5s to remove placement_group")
    time.sleep(5)
    # test single-node-multi-partition

    print("test single-node-multi-partition")
    rm_resource_pool = RayResourcePool([4], use_gpu=True, name_prefix="rm")
    ref_resource_pool = RayResourcePool([4], use_gpu=True, name_prefix="ref")
    total_resource_pool = merge_resource_pool(rm_resource_pool, ref_resource_pool)

    assert rm_resource_pool.world_size == 4
    assert ref_resource_pool.world_size == 4
    assert total_resource_pool.world_size == 8

    actor_wg = RayWorkerGroup(
        total_resource_pool, class_with_args, name_prefix="high_level_api_actor", device_name=get_device_name()
    )
    critic_wg = RayWorkerGroup(
        total_resource_pool, class_with_args, name_prefix="high_level_api_critic", device_name=get_device_name()
    )
    rm_wg = RayWorkerGroup(
        rm_resource_pool, class_with_args, name_prefix="high_level_api_rm", device_name=get_device_name()
    )
    ref_wg = RayWorkerGroup(
        ref_resource_pool, class_with_args, name_prefix="high_level_api_ref", device_name=get_device_name()
    )

    # Merged-pool groups see all 8 GPUs; each sub-pool group sees its half.
    assert actor_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]
    assert critic_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]
    assert rm_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(4)]
    assert ref_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(4, 8)]

    ray.shutdown()
|
code/RL_model/verl/verl_train/tests/single_controller/test_rvdz.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import ray
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@ray.remote
class TestWorker:
    """Minimal actor that joins an NCCL group through the ray rendezvous backend."""

    def __init__(self, rank, world_size, group_name):
        self.rank = rank
        self.world_size = world_size
        self.group_name = group_name
        # Created lazily in init() so actor construction stays cheap.
        self.communicator = None

    def init(self):
        """Create the NCCL communicator for this rank via the ray rendezvous."""
        from verl.utils.rendezvous.ray_backend import create_nccl_communicator_in_ray

        self.communicator = create_nccl_communicator_in_ray(self.rank, self.world_size, self.group_name)

    def test(self):
        """Return this rank's id inside the communicator, or None if uninitialized."""
        comm = self.communicator
        return None if comm is None else comm.rank_id()
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def test_rvdz():
    """Spin up two single-GPU actors, rendezvous them into one NCCL group, and check rank ids."""
    ray.init()

    group_name = "test_group"
    world_size = 2

    workers = [TestWorker.options(num_gpus=1).remote(rank, world_size, group_name) for rank in range(world_size)]

    # Block until every worker has built its communicator.
    ray.get([w.init.remote() for w in workers])

    ranks = ray.get([w.test.remote() for w in workers])

    assert ranks == [0, 1], f"expecting [0, 1], got {ranks}"

    ray.shutdown()
|
code/RL_model/verl/verl_train/tests/single_controller/test_worker_group_torch.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
|
| 17 |
+
os.environ["RAY_DEDUP_LOGS"] = "0"
|
| 18 |
+
os.environ["NCCL_DEBUG"] = "WARN"
|
| 19 |
+
|
| 20 |
+
import ray
|
| 21 |
+
import torch
|
| 22 |
+
import torch.distributed
|
| 23 |
+
|
| 24 |
+
from verl.single_controller.base.worker import Worker
|
| 25 |
+
from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
|
| 26 |
+
from verl.utils.device import get_device_name
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@ray.remote
|
| 30 |
+
class TestAllGatherActor(Worker):
|
| 31 |
+
def __init__(self, size) -> None:
|
| 32 |
+
super().__init__()
|
| 33 |
+
self.size = size
|
| 34 |
+
|
| 35 |
+
def init(self):
|
| 36 |
+
torch.distributed.init_process_group()
|
| 37 |
+
self.tensor = torch.zeros(size=(self.size,), dtype=torch.int64, device=get_device_name())
|
| 38 |
+
self.tensor += self.rank
|
| 39 |
+
|
| 40 |
+
def all_gather(self):
|
| 41 |
+
world_size = self._world_size
|
| 42 |
+
output = torch.zeros(
|
| 43 |
+
size=(self.tensor.shape[0] * world_size,), dtype=self.tensor.dtype, device=self.tensor.device
|
| 44 |
+
)
|
| 45 |
+
torch.distributed.all_gather_into_tensor(output, self.tensor, async_op=False)
|
| 46 |
+
return output
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@ray.remote
class TestAllGatherActorV2(Worker):
    """Like TestAllGatherActor, but joins the process group in __init__ (no separate init call)."""

    def __init__(self, size) -> None:
        super().__init__()
        self.size = size

        torch.distributed.init_process_group()
        base = torch.zeros(size=(self.size,), dtype=torch.int64, device=get_device_name())
        self.tensor = base + self.rank

    def all_gather(self):
        """All-gather every rank's tensor into one flat tensor and return it."""
        num_ranks = self._world_size
        gathered = torch.zeros(
            size=(self.tensor.shape[0] * num_ranks,), dtype=self.tensor.dtype, device=self.tensor.device
        )
        torch.distributed.all_gather_into_tensor(gathered, self.tensor, async_op=False)
        return gathered
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def test_all_gather_torch():
    """
    Instantiate 4 GPU workers in one group and check all_gather returns the
    same, correct result on every rank.
    """
    ray.init()

    # One worker per GPU, 4 GPUs total.
    pool = RayResourcePool([4], use_gpu=True)
    actor_cls = RayClassWithInitArgs(cls=TestAllGatherActor, size=2)

    worker_group = RayWorkerGroup(
        pool, actor_cls, name_prefix="worker_group_torch", device_name=get_device_name()
    )

    worker_group.execute_all_sync("init")
    gathered = worker_group.execute_all_sync("all_gather")
    reference = gathered[0]
    for other in gathered[1:]:
        assert torch.all(other == reference)

    reference = reference.cpu()
    print(reference)
    # Each rank contributed two copies of its own rank id.
    assert torch.all(reference == torch.tensor([0, 0, 1, 1, 2, 2, 3, 3], dtype=torch.int64))

    ray.shutdown()
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def test_all_gather_torch_v2():
    """
    Same as test_all_gather_torch, but with the V2 actor that initializes its
    process group inside __init__, so no explicit "init" call is needed.
    """
    ray.init()

    # One worker per GPU, 4 GPUs total.
    pool = RayResourcePool([4], use_gpu=True)
    actor_cls = RayClassWithInitArgs(cls=TestAllGatherActorV2, size=2)

    worker_group = RayWorkerGroup(
        pool, actor_cls, name_prefix="worker_group_torch", device_name=get_device_name()
    )

    gathered = worker_group.execute_all_sync("all_gather")
    reference = gathered[0]
    for other in gathered[1:]:
        assert torch.all(other == reference)

    reference = reference.cpu()
    print(reference)
    # Each rank contributed two copies of its own rank id.
    assert torch.all(reference == torch.tensor([0, 0, 1, 1, 2, 2, 3, 3], dtype=torch.int64))

    ray.shutdown()
|
code/RL_model/verl/verl_train/tests/special_e2e/README.md
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
This folder is reserved for end-to-end tests that typically require multiple GPUs.
|
code/RL_model/verl/verl_train/tests/utils/test_activation_offload.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import os
|
| 15 |
+
import shutil
|
| 16 |
+
import tempfile
|
| 17 |
+
|
| 18 |
+
import pytest
|
| 19 |
+
import torch
|
| 20 |
+
import torch.distributed
|
| 21 |
+
import torch.multiprocessing as mp
|
| 22 |
+
from torch.distributed import init_device_mesh
|
| 23 |
+
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
| 24 |
+
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
|
| 25 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
|
| 26 |
+
|
| 27 |
+
from verl.utils.activation_offload import enable_activation_offloading
|
| 28 |
+
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
|
| 29 |
+
from verl.utils.device import get_device_name, get_nccl_backend, get_torch_device
|
| 30 |
+
from verl.utils.fsdp_utils import MixedPrecisionPolicy, apply_fsdp2, get_fsdp_wrap_policy
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def create_random_input_ids(batch_size, seq_len, vocab_size):
    """Build packed (unpadded) random input_ids and position_ids for attention tests.

    A random attention mask is generated, position ids are derived from it, and
    both tensors are unpadded with the device-appropriate ``unpad_input`` backend
    (flash-attn on CUDA, verl's implementation on NPU).

    Args:
        batch_size: number of sequences in the batch.
        seq_len: padded length of each sequence.
        vocab_size: exclusive upper bound for sampled token ids.

    Returns:
        Tuple of (input_ids, position_ids); after unpad + transpose both are
        presumably shaped (1, total_valid_tokens) — confirm against unpad_input.

    Raises:
        RuntimeError: if the current device is neither "cuda" nor "npu".
    """
    device = get_device_name()
    if device == "cuda":
        from flash_attn.bert_padding import unpad_input
    elif device == "npu":
        from verl.utils.attention_utils import unpad_input
    else:
        # Previously this fell through and crashed later with a NameError on
        # unpad_input; fail early with an explicit message instead.
        raise RuntimeError(f"create_random_input_ids: unsupported device {device!r}, expected 'cuda' or 'npu'")
    from verl.utils.model import compute_position_id_with_mask, create_random_mask

    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)

    attention_mask = create_random_mask(
        input_ids, max_ratio_of_left_padding=0.1, min_ratio_of_valid_token=0.5, max_ratio_of_valid_token=0.7
    )
    position_ids = compute_position_id_with_mask(attention_mask)

    input_ids = unpad_input(input_ids.unsqueeze(-1), attention_mask)[0].transpose(0, 1)
    position_ids = unpad_input(position_ids.unsqueeze(-1), attention_mask)[0].transpose(0, 1)
    return input_ids, position_ids
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _fsdp_activation_offloading_test(rank, world_size, rendezvous_file, strategy="fsdp"):
    """Per-rank worker: verify activation offloading does not change model outputs.

    The check: train a tiny model for one step and checkpoint it; train a second
    step and record the logits; then turn activation offloading on, restore the
    checkpoint, repeat the same second step, and assert the logits are bit-exact.

    Args:
        rank: this process's rank within the spawned group.
        world_size: total number of spawned ranks.
        rendezvous_file: file path used for file:// process-group rendezvous.
        strategy: "fsdp" wraps with FSDP1; any other value uses FSDP2 via apply_fsdp2.
    """
    get_torch_device().set_device(rank)
    torch.distributed.init_process_group(
        backend=get_nccl_backend(),
        init_method=f"file://{rendezvous_file}",
        rank=rank,
        world_size=world_size,
    )
    device_mesh = init_device_mesh(get_device_name(), mesh_shape=(world_size,), mesh_dim_names=("dp",))

    # The tokenizer comes from a local model path; the model itself is built
    # from a small 4-layer config so the test stays cheap.
    model_name = os.path.expanduser("~/models/Qwen/Qwen2.5-0.5B-Instruct")
    config = Qwen2Config(num_hidden_layers=4)

    with torch.device(get_device_name()):
        model = AutoModelForCausalLM.from_config(
            config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
        )
        model = model.to(device=get_device_name())

    # Wrap model with FSDP (v1 or v2 depending on `strategy`).
    mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32)

    if strategy == "fsdp":
        model = FSDP(
            model,
            use_orig_params=False,
            device_id=get_torch_device().current_device(),
            sharding_strategy=ShardingStrategy.FULL_SHARD,
            mixed_precision=mixed_precision,
            device_mesh=device_mesh,
            auto_wrap_policy=get_fsdp_wrap_policy(module=model),
        )
    else:
        mp_policy = MixedPrecisionPolicy(
            param_dtype=torch.bfloat16, reduce_dtype=torch.float32, cast_forward_inputs=True
        )
        fsdp_kwargs = {
            "mesh": device_mesh,
            "mp_policy": mp_policy,
        }
        apply_fsdp2(model, fsdp_kwargs, {})

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

    # Create checkpoint manager
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    checkpoint_manager = FSDPCheckpointManager(
        model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, tokenizer=tokenizer
    )

    # Generate sample inputs: one batch for the initial update, one for verification.
    batch_size = 2
    seq_len = 32
    vocab_size = 32000
    input_ids1, position_ids1 = create_random_input_ids(batch_size, seq_len, vocab_size)
    input_ids2, position_ids2 = create_random_input_ids(batch_size, seq_len, vocab_size)

    # Step 1: initial update, then save a checkpoint of this state.
    outputs1 = model(input_ids=input_ids1, position_ids=position_ids1)
    loss1 = outputs1.logits.mean()
    loss1.backward()
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()

    temp_dir = tempfile.mkdtemp()
    checkpoint_path = os.path.join(temp_dir, "checkpoint")
    checkpoint_manager.save_checkpoint(local_path=checkpoint_path, hdfs_path=None, global_step=0)

    # Step 2: second update (without offloading) and record the resulting logits.
    outputs2 = model(input_ids=input_ids2, position_ids=position_ids2)
    loss2 = outputs2.logits.mean()
    loss2.backward()
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()

    with torch.no_grad():
        logits_without_offloading = model(input_ids=input_ids2, position_ids=position_ids2).logits

    # Step 3: enable activation offloading and roll the model back to the checkpoint.
    enable_activation_offloading(model, strategy=strategy)
    checkpoint_manager.load_checkpoint(checkpoint_path)

    # Step 4: repeat the second update with the same input, now with offloading on.
    outputs3 = model(input_ids=input_ids2, position_ids=position_ids2)
    loss3 = outputs3.logits.mean()
    loss3.backward()
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()

    with torch.no_grad():
        logits_with_offloading = model(input_ids=input_ids2, position_ids=position_ids2).logits

    # Step 5: offloading must be bit-exact — zero tolerance on the comparison.
    torch.testing.assert_close(logits_without_offloading, logits_with_offloading, atol=0.0, rtol=0.0)
    # Fixed typo in the success message ("Activaiton" -> "Activation").
    print(f"Activation offloading for {strategy} test passed on {world_size} GPUs!")

    # Cleanup
    shutil.rmtree(temp_dir)
    torch.distributed.barrier()
    torch.distributed.destroy_process_group()
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
@pytest.mark.parametrize("world_size", (2, 4))
|
| 165 |
+
@pytest.mark.parametrize("strategy", ("fsdp", "fsdp2"))
|
| 166 |
+
def test_activation_offloading(world_size, strategy, tmp_path):
|
| 167 |
+
rendezvous_file = str(tmp_path / "rdzv_file")
|
| 168 |
+
os.makedirs(os.path.dirname(rendezvous_file), exist_ok=True)
|
| 169 |
+
|
| 170 |
+
mp.spawn(
|
| 171 |
+
fn=_fsdp_activation_offloading_test,
|
| 172 |
+
args=(world_size, rendezvous_file, strategy),
|
| 173 |
+
nprocs=world_size,
|
| 174 |
+
join=True,
|
| 175 |
+
)
|
code/RL_model/verl/verl_train/tests/utils/test_check_ipc_version_support_on_npu.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# This code is licensed under the MIT-style license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
import unittest
|
| 8 |
+
from unittest.mock import Mock, mock_open, patch
|
| 9 |
+
|
| 10 |
+
from verl.utils.device import check_ipc_version_support, get_npu_versions
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class TestCheckIPCVersionSupport(unittest.TestCase):
    """Test cases for the check_ipc_version_support function."""

    def setUp(self):
        """Silence INFO-level log output for the duration of each test."""
        logging.disable(logging.INFO)

    def tearDown(self):
        """Re-enable all logging levels."""
        logging.disable(logging.NOTSET)

    def test_standard_version_with_support(self):
        """A standard version meeting the minimum requirements is supported."""
        # Software 25.5.0 >= 25.3.rc1, CANN 8.3.0 >= 8.3.rc1
        self.assertTrue(check_ipc_version_support("25.5.0", "8.3.0"))

    def test_standard_version_newer(self):
        """Newer standard versions are supported."""
        # Software 26.0.0 >= 25.3.rc1, CANN 9.0.0 >= 8.3.rc1
        self.assertTrue(check_ipc_version_support("26.0.0", "9.0.0"))

    def test_rc_version_format(self):
        """RC versions with trailing parts are truncated and compared correctly."""
        # 25.3.rc1.2 -> 25.3.rc1 >= 25.3.rc1; 8.3.rc1.2 -> 8.3.rc1 >= 8.3.rc1
        self.assertTrue(check_ipc_version_support("25.3.rc1.2", "8.3.rc1.2"))

    def test_exact_rc_version(self):
        """Exactly the minimum RC versions are supported."""
        self.assertTrue(check_ipc_version_support("25.3.rc1", "8.3.rc1"))

    def test_t_suffix_version(self):
        """A software version with a lowercase 't' suffix is accepted."""
        # 25.5.t3.b001 -> 25.5 >= 25.3.rc1; CANN 8.3.rc1 >= 8.3.rc1
        self.assertTrue(check_ipc_version_support("25.5.t3.b001", "8.3.rc1"))

    def test_t_suffix_version_older(self):
        """A 't'-suffixed software version with an old CANN version is rejected."""
        # Software passes (25.5 >= 25.3.rc1) but CANN 8.2.rc1 < 8.3.rc1
        self.assertFalse(check_ipc_version_support("25.5.t3.b001", "8.2.rc1"))

    def test_software_version_below_minimum(self):
        """A software version below the minimum is rejected."""
        # Software 25.2.0 < 25.3.rc1
        self.assertFalse(check_ipc_version_support("25.2.0", "8.3.0"))

    def test_cann_version_below_minimum(self):
        """A CANN version below the minimum is rejected."""
        # Software ok, but CANN 8.2.0 < 8.3.rc1
        self.assertFalse(check_ipc_version_support("25.5.0", "8.2.0"))

    def test_both_versions_below_minimum(self):
        """Both versions below the minimum is rejected."""
        self.assertFalse(check_ipc_version_support("25.2.0", "8.2.0"))

    def test_invalid_software_version(self):
        """A malformed software version raises RuntimeError."""
        with self.assertRaises(RuntimeError) as ctx:
            check_ipc_version_support("invalid.version", "8.3.0")
        self.assertIn("Invalid software version format", str(ctx.exception))

    def test_invalid_cann_version(self):
        """A malformed CANN version raises RuntimeError."""
        with self.assertRaises(RuntimeError) as ctx:
            check_ipc_version_support("25.5.0", "invalid.version")
        self.assertIn("Invalid CANN version format", str(ctx.exception))

    def test_rc_with_more_parts(self):
        """RC versions with more than 3 parts keep only the first 3 (25.3.rc1)."""
        self.assertTrue(check_ipc_version_support("25.3.rc1.2.3.4", "8.3.rc1.2.3.4"))

    def test_standard_with_more_parts(self):
        """Standard versions with more than 3 parts keep only the first 3 (25.5.0)."""
        self.assertTrue(check_ipc_version_support("25.5.0.1.2.3", "8.3.0.1.2.3"))

    def test_rc_edge_case_versions(self):
        """RC boundary values: rc1 is the minimum, rc0 falls below it."""
        self.assertTrue(check_ipc_version_support("25.3.rc1", "8.3.rc1"))
        self.assertFalse(check_ipc_version_support("25.3.rc0", "8.3.rc1"))

    def test_major_version_differences(self):
        """Major version jumps in either direction are handled."""
        self.assertTrue(check_ipc_version_support("30.0.0", "10.0.0"))
        self.assertFalse(check_ipc_version_support("24.0.0", "7.0.0"))
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class TestGetNPUVersions(unittest.TestCase):
    """Test cases for the get_npu_versions function.

    NOTE: stacked ``@patch`` decorators inject mocks bottom-up, so in each test
    signature the first mock argument corresponds to the lowest decorator.
    """

    @patch("subprocess.run")
    @patch("platform.machine")
    @patch("os.path.exists")
    @patch("builtins.open", new_callable=mock_open, read_data="version=8.3.rc1\n")
    def test_get_npu_versions_success(self, mock_file, mock_exists, mock_machine, mock_run):
        """Test successful retrieval of versions."""
        # Mock npu-smi output
        mock_run.return_value = Mock(stdout="Software Version : 25.5.0\nOther Info\n", check=True)

        # Mock architecture
        mock_machine.return_value = "x86_64"

        # Mock path exists
        mock_exists.return_value = True

        software_version, cann_version = get_npu_versions()

        # Software version parsed from npu-smi stdout; CANN version from the info file.
        self.assertEqual(software_version, "25.5.0")
        self.assertEqual(cann_version, "8.3.rc1")

    @patch("subprocess.run")
    def test_get_npu_versions_missing_software_version(self, mock_run):
        """Test error when Software Version is missing."""
        mock_run.return_value = Mock(stdout="Other Info Without Software Version\n", check=True)

        with self.assertRaises(RuntimeError) as context:
            get_npu_versions()

        self.assertIn("Could not find Software Version", str(context.exception))

    @patch("subprocess.run")
    @patch("platform.machine")
    @patch("os.path.exists")
    @patch("builtins.open", new_callable=mock_open, read_data="version=8.3.rc1\n")
    def test_get_npu_versions_unsupported_architecture(self, mock_file, mock_exists, mock_machine, mock_run):
        """Test error with unsupported architecture."""
        mock_run.return_value = Mock(stdout="Software Version : 25.5.0\n", check=True)

        mock_machine.return_value = "armv7l"  # Unsupported architecture
        mock_exists.return_value = True

        with self.assertRaises(RuntimeError) as context:
            get_npu_versions()

        self.assertIn("Unsupported architecture", str(context.exception))

    @patch("subprocess.run")
    @patch("platform.machine")
    @patch("os.path.exists")
    @patch("builtins.open", new_callable=mock_open, read_data="version=8.3.rc1\n")
    def test_get_npu_versions_cann_path_not_exists(self, mock_file, mock_exists, mock_machine, mock_run):
        """Test error when CANN path doesn't exist."""
        mock_run.return_value = Mock(stdout="Software Version : 25.5.0\n", check=True)

        mock_machine.return_value = "x86_64"
        mock_exists.return_value = False  # Path doesn't exist

        with self.assertRaises(RuntimeError) as context:
            get_npu_versions()

        self.assertIn("CANN toolkit path does not exist", str(context.exception))

    @patch("subprocess.run")
    @patch("platform.machine")
    @patch("os.path.exists")
    @patch("builtins.open")
    def test_get_npu_versions_info_file_not_exists(self, mock_file, mock_exists, mock_machine, mock_run):
        """Test error when CANN info file doesn't exist."""
        mock_run.return_value = Mock(stdout="Software Version : 25.5.0\n", check=True)

        mock_machine.return_value = "x86_64"

        # First call is for CANN path exists, second call is for info file exists
        mock_exists.side_effect = [True, False]

        with self.assertRaises(RuntimeError) as context:
            get_npu_versions()

        self.assertIn("CANN toolkit info file does not exist", str(context.exception))

    @patch("subprocess.run")
    @patch("platform.machine")
    @patch("os.path.exists")
    @patch("builtins.open", new_callable=mock_open, read_data="other_info=no_version\n")
    def test_get_npu_versions_missing_cann_version(self, mock_file, mock_exists, mock_machine, mock_run):
        """Test error when CANN version is missing from info file."""
        mock_run.return_value = Mock(stdout="Software Version : 25.5.0\n", check=True)

        mock_machine.return_value = "x86_64"
        mock_exists.return_value = True

        with self.assertRaises(RuntimeError) as context:
            get_npu_versions()

        self.assertIn("Could not find version in CANN toolkit info file", str(context.exception))
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
if __name__ == "__main__":
|
| 231 |
+
unittest.main()
|
code/RL_model/verl/verl_train/tests/utils/test_config_on_cpu.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import unittest
|
| 16 |
+
from dataclasses import dataclass, field
|
| 17 |
+
|
| 18 |
+
from omegaconf import OmegaConf
|
| 19 |
+
|
| 20 |
+
from verl.base_config import BaseConfig
|
| 21 |
+
from verl.utils import omega_conf_to_dataclass
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass
class TestDataclass(BaseConfig):
    """Minimal leaf config used to exercise omega_conf_to_dataclass conversion."""

    # Hidden-layer width; the YAML fixture below overrides this to 768.
    hidden_size: int = 0
    # Activation function name.
    activation: str = "relu"
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass
class TestTrainConfig(BaseConfig):
    """Nested config (with a TestDataclass child) for the hierarchical conversion test."""

    # Training batch size; the YAML fixture below overrides this to 32.
    batch_size: int = 0
    # Nested model sub-config.
    model: TestDataclass = field(default_factory=TestDataclass)
    # Free-form override mapping; empty in the fixture.
    override_config: dict = field(default_factory=dict)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
_cfg_str = """train_config:
|
| 38 |
+
_target_: tests.utils.test_config_on_cpu.TestTrainConfig
|
| 39 |
+
batch_size: 32
|
| 40 |
+
model:
|
| 41 |
+
hidden_size: 768
|
| 42 |
+
activation: relu
|
| 43 |
+
override_config: {}"""
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class TestConfigOnCPU(unittest.TestCase):
    """Test cases for configuration utilities on CPU.

    Test Plan:
    1. Convert a simple nested OmegaConf node into its dataclass.
    2. Convert a full hierarchical config into a nested dataclass.
    3. Check every converted value is correct and accessible.
    """

    def setUp(self):
        # Parse the shared YAML fixture once per test.
        self.config = OmegaConf.create(_cfg_str)

    def test_omega_conf_to_dataclass(self):
        model_node = self.config.train_config.model
        converted = omega_conf_to_dataclass(model_node, TestDataclass)
        self.assertEqual(converted.hidden_size, 768)
        self.assertEqual(converted.activation, "relu")
        assert isinstance(converted, TestDataclass)

    def test_nested_omega_conf_to_dataclass(self):
        converted = omega_conf_to_dataclass(self.config.train_config, TestTrainConfig)
        self.assertEqual(converted.batch_size, 32)
        self.assertEqual(converted.model.hidden_size, 768)
        self.assertEqual(converted.model.activation, "relu")
        assert isinstance(converted, TestTrainConfig)
        assert isinstance(converted.model, TestDataclass)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class TestPrintCfgCommand(unittest.TestCase):
    """Test suite for the print_cfg.py command-line tool."""

    def test_command_with_override(self):
        """The CLI exits cleanly and its output mentions the expected config sections."""
        import subprocess
        import sys

        # Use the interpreter that is running the tests rather than whatever
        # "python3" resolves to on PATH, so the subprocess shares our
        # environment (virtualenv, installed packages).
        result = subprocess.run(
            [sys.executable, "scripts/print_cfg.py"],
            capture_output=True,
            text=True,
        )

        # Verify the command exited successfully
        self.assertEqual(result.returncode, 0, f"Command failed with stderr: {result.stderr}")

        # Verify the output contains expected config information
        self.assertIn("critic", result.stdout)
        self.assertIn("profiler", result.stdout)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
if __name__ == "__main__":
|
| 97 |
+
unittest.main()
|
code/RL_model/verl/verl_train/tests/utils/test_flops_counter.py
ADDED
|
@@ -0,0 +1,480 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
|
| 17 |
+
import pytest
|
| 18 |
+
|
| 19 |
+
from verl.utils.flops_counter import FlopsCounter
|
| 20 |
+
|
| 21 |
+
VALID_CONFIG_TYPE = {"llama", "qwen2", "qwen3", "qwen3_moe", "deepseek_v3", "mistral", "gemma3_text", "apertus"}
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Config:
    """Attribute-access wrapper around a (possibly nested) plain dict.

    Every key of ``config_dict`` becomes an attribute on the instance;
    nested dicts are wrapped recursively, so ``cfg.text_config.hidden_size``
    works for ``{"text_config": {"hidden_size": ...}}``.
    """

    def __init__(self, config_dict):
        for name, entry in config_dict.items():
            # Recurse into sub-dicts so nested config sections are also
            # reachable via dotted attribute access.
            wrapped = Config(entry) if isinstance(entry, dict) else entry
            setattr(self, name, wrapped)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
CONFIG = {
|
| 33 |
+
"llama": {
|
| 34 |
+
"config": { # llama2-7B
|
| 35 |
+
"model_type": "llama",
|
| 36 |
+
"vocab_size": 32000,
|
| 37 |
+
"hidden_size": 4096,
|
| 38 |
+
"intermediate_size": 11008,
|
| 39 |
+
"num_hidden_layers": 32,
|
| 40 |
+
"num_attention_heads": 32,
|
| 41 |
+
"num_key_value_heads": 32,
|
| 42 |
+
},
|
| 43 |
+
"batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]),
|
| 44 |
+
# 6*(vocab*hidden*2+layer*(hidden*(q+k+v+head*head_dim)+ hidden*inter*3))*token_sum +
|
| 45 |
+
# 6*sum(seqlen^2)*layer*head*head_dim
|
| 46 |
+
# 6*(32000*4096*2+32*(4096*4096*4+4096*11008*3))*(512+1024+2048) +
|
| 47 |
+
# 6*(512*512+1024*1024+2048*2048)*32*4096
|
| 48 |
+
# 6*(32000*4096*2+32*(4096*4096*4+4096*11008*3))*(4096+4096+4096) +
|
| 49 |
+
# 6*(4096*4096+4096*4096+4096*4096)*32*4096
|
| 50 |
+
"expected_flops_tuple": (149226491215872 / 1e12, 536372695793664 / 1e12),
|
| 51 |
+
},
|
| 52 |
+
"qwen2": {
|
| 53 |
+
"config": { # Qwen/Qwen2.5-7B-Instruct
|
| 54 |
+
"model_type": "qwen2",
|
| 55 |
+
"vocab_size": 152064,
|
| 56 |
+
"hidden_size": 3584,
|
| 57 |
+
"intermediate_size": 18944,
|
| 58 |
+
"num_hidden_layers": 28,
|
| 59 |
+
"num_attention_heads": 28,
|
| 60 |
+
"num_key_value_heads": 4,
|
| 61 |
+
},
|
| 62 |
+
"batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]),
|
| 63 |
+
# 6*(vocab*hidden*2+layer*(hidden*(q+k+v+head*head_dim)+ hidden*inter*3))*token_sum +
|
| 64 |
+
# 6*sum(seqlen^2)*layer*head*head_dim
|
| 65 |
+
# 6*(152064*3584*2+28*(3584*(3584+512+512+3584)+3584*18944*3))*(512+1024+2048) +
|
| 66 |
+
# 6*(512*512+1024*1024+2048*2048)*28*3584
|
| 67 |
+
# 6*(152064*3584*2+28*(3584*(3584+512+512+3584)+3584*18944*3))*(4096+4096+4096) +
|
| 68 |
+
# 6*(4096*4096+4096*4096+4096*4096)*28*3584
|
| 69 |
+
"expected_flops_tuple": (167073690943488 / 1e12, 591764889010176 / 1e12),
|
| 70 |
+
},
|
| 71 |
+
"qwen3": {
|
| 72 |
+
"config": { # Qwen/Qwen3-8B
|
| 73 |
+
"model_type": "qwen3",
|
| 74 |
+
"vocab_size": 151936,
|
| 75 |
+
"hidden_size": 4096,
|
| 76 |
+
"intermediate_size": 12288,
|
| 77 |
+
"num_hidden_layers": 36,
|
| 78 |
+
"num_attention_heads": 32,
|
| 79 |
+
"num_key_value_heads": 8,
|
| 80 |
+
"head_dim": 128,
|
| 81 |
+
},
|
| 82 |
+
"batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]),
|
| 83 |
+
# 6*(vocab*hidden*2+layer*(hidden*(q+k+v+head*head_dim)+ hidden*inter*3))*token_sum +
|
| 84 |
+
# 6*sum(seqlen^2)*layer*head*head_dim
|
| 85 |
+
# 6*(151936*4096*2+36*(4096*(128*32+128*8*2+128*32)+4096*12288*3))*(512+1024+2048) +
|
| 86 |
+
# 6*(512*512+1024*1024+2048*2048)*36*128*32
|
| 87 |
+
# 6*(151936*4096*2+36*(4096*(128*32+128*8*2+128*32)+4096*12288*3))*(4096+4096+4096) +
|
| 88 |
+
# 6*(4096*4096+4096*4096+4096*4096)*36*128*32
|
| 89 |
+
"expected_flops_tuple": (180997438046208 / 1e12, 648394032807936 / 1e12),
|
| 90 |
+
},
|
| 91 |
+
"qwen3_moe": {
|
| 92 |
+
"config": { # Qwen/Qwen3-30B-A3B-Base
|
| 93 |
+
"model_type": "qwen3_moe",
|
| 94 |
+
"hidden_size": 2048,
|
| 95 |
+
"vocab_size": 151936,
|
| 96 |
+
"num_hidden_layers": 48,
|
| 97 |
+
"num_key_value_heads": 4,
|
| 98 |
+
"num_attention_heads": 32,
|
| 99 |
+
"head_dim": 128,
|
| 100 |
+
"moe_intermediate_size": 768,
|
| 101 |
+
"num_experts_per_tok": 8,
|
| 102 |
+
"num_experts": 128,
|
| 103 |
+
},
|
| 104 |
+
"batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]),
|
| 105 |
+
# 6*(vocab*hidden*2+layer*(hidden*(q+k+v+head*head_dim)+hidden*inter*top_k_exp*3 +
|
| 106 |
+
# hidden*num_experts))*token_sum + 6*sum(seqlen^2)*layer*head*head_dim
|
| 107 |
+
# 6*(151936*2048*2+48*(2048*(128*32+128*4*2+128*32)+2048*768*8*3+2048*128))*(512+1024+2048) +
|
| 108 |
+
# 6*(512*512+1024*1024+2048*2048)*48*128*32
|
| 109 |
+
# 6*(151936*2048*2+48*(2048*(128*32+128*4*2+128*32)+2048*768*8*3+2048*128))*(4096+4096+4096) +
|
| 110 |
+
# 6*(4096*4096+4096*4096+4096*4096)*48*128*32
|
| 111 |
+
"expected_flops_tuple": (78593069678592 / 1e12, 306570470621184 / 1e12),
|
| 112 |
+
},
|
| 113 |
+
"deepseek_v3": {
|
| 114 |
+
"config": { # deepseek-ai/DeepSeek-Prover-V2-671B
|
| 115 |
+
"model_type": "deepseek_v3",
|
| 116 |
+
"hidden_size": 7168,
|
| 117 |
+
"vocab_size": 129280,
|
| 118 |
+
"moe_intermediate_size": 2048,
|
| 119 |
+
"num_hidden_layers": 61,
|
| 120 |
+
"first_k_dense_replace": 3,
|
| 121 |
+
"num_attention_heads": 128,
|
| 122 |
+
"n_routed_experts": 256,
|
| 123 |
+
"num_experts_per_tok": 8,
|
| 124 |
+
"n_shared_experts": 1,
|
| 125 |
+
"kv_lora_rank": 512,
|
| 126 |
+
"qk_rope_head_dim": 64,
|
| 127 |
+
"v_head_dim": 128,
|
| 128 |
+
"intermediate_size": 18432,
|
| 129 |
+
"qk_nope_head_dim": 128,
|
| 130 |
+
"q_lora_rank": 1536,
|
| 131 |
+
},
|
| 132 |
+
"batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]),
|
| 133 |
+
# (1536*7168+128*192*1536+7168*(512+64)+128*(128+128)*512+128*128*7168) = 187105280
|
| 134 |
+
# 6*(129280*7168*2+ 3*(7168*18432*3+187105280)+ 58*(187105280+7168*256+7168*2048*9*3))*(512+1024+2048) +
|
| 135 |
+
# 3*(512*512+1024*1024+2048*2048)*61*(192+128)*128
|
| 136 |
+
# 6*(129280*7168*2+ 3*(7168*18432*3+187105280)+ 58*(187105280+7168*256+7168*2048*9*3))*(4096+4096+4096) +
|
| 137 |
+
# 3*(4096*4096+4096*4096+4096*4096)*61*(192+128)*128
|
| 138 |
+
"expected_flops_tuple": (848766538088448 / 1e12, 3145850406567936 / 1e12),
|
| 139 |
+
},
|
| 140 |
+
"mistral": {
|
| 141 |
+
"config": { # mistralai/Mistral-Small-24B-Instruct-2501
|
| 142 |
+
"model_type": "mistral",
|
| 143 |
+
"vocab_size": 131072,
|
| 144 |
+
"hidden_size": 5120,
|
| 145 |
+
"intermediate_size": 32768,
|
| 146 |
+
"num_hidden_layers": 40,
|
| 147 |
+
"num_attention_heads": 32,
|
| 148 |
+
"num_key_value_heads": 8,
|
| 149 |
+
"head_dim": 128,
|
| 150 |
+
},
|
| 151 |
+
"batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]),
|
| 152 |
+
# Mistral uses same architecture as Llama, with GQA
|
| 153 |
+
# 6*(vocab*hidden*2+layer*(hidden*(q+k+v+head*head_dim)+ hidden*inter*3))*token_sum +
|
| 154 |
+
# 12*sum(seqlen^2)*layer*head*head_dim
|
| 155 |
+
# vocab part: 131072*5120*2 = 1342177280
|
| 156 |
+
# attn part per layer: 5120*(128*32+128*8+128*8+128*32) = 5120*10240 = 52428800
|
| 157 |
+
# mlp part per layer: 5120*32768*3 = 503316480
|
| 158 |
+
# total per layer: 52428800 + 503316480 = 555745280
|
| 159 |
+
# all layers: 1342177280 + 40*555745280 = 23571988480
|
| 160 |
+
# For batch [512, 1024, 2048], tokens_sum = 3584:
|
| 161 |
+
# dense flops: 6 * 23571988480 * 3584 = 506892040273920
|
| 162 |
+
# attn flops: 6 * 5505024 * 40 * 128 * 32 = 10823317585920
|
| 163 |
+
# total: 517715357859840 / 1e12 = 517.71535785984
|
| 164 |
+
# For batch [4096, 4096, 4096], tokens_sum = 12288:
|
| 165 |
+
# dense flops: 6 * 23571988480 * 12288 = 1737915566653440
|
| 166 |
+
# attn flops: 6 * 50331648 * 40 * 128 * 32 = 98956046499840
|
| 167 |
+
# total: 1836871613153280 / 1e12 = 1836.87161315328
|
| 168 |
+
"expected_flops_tuple": (512303699066880 / 1e12, 1787393589903360 / 1e12),
|
| 169 |
+
},
|
| 170 |
+
"gemma3_text": {
|
| 171 |
+
"config": { # Gemma3-12B-IT-TextOnly
|
| 172 |
+
"model_type": "gemma3_text",
|
| 173 |
+
"vocab_size": 262208,
|
| 174 |
+
"hidden_size": 3840,
|
| 175 |
+
"intermediate_size": 15360,
|
| 176 |
+
"num_hidden_layers": 48,
|
| 177 |
+
"num_attention_heads": 16,
|
| 178 |
+
"num_key_value_heads": 8,
|
| 179 |
+
"head_dim": 256,
|
| 180 |
+
"sliding_window": 1024,
|
| 181 |
+
"layer_types": None,
|
| 182 |
+
# Will be auto-generated based on sliding_window_pattern
|
| 183 |
+
"sliding_window_pattern": 6,
|
| 184 |
+
# Every 6th layer is full attention
|
| 185 |
+
},
|
| 186 |
+
"batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]),
|
| 187 |
+
# Gemma3 has alternating sliding window attention
|
| 188 |
+
# With sliding_window_pattern=6: layers 5,11,17,23,29,35,41,47 use full attention (8 layers)
|
| 189 |
+
# Other 40 layers use sliding window attention with window_size=1024
|
| 190 |
+
#
|
| 191 |
+
# Non-attention FLOPs:
|
| 192 |
+
# vocab part: 262208*3840*2 = 2013757440
|
| 193 |
+
# attn part per layer: 3840*(256*16+256*8+256*8+256*16) = 3840*12288 = 47185920
|
| 194 |
+
# mlp part per layer: 3840*15360*3 = 176947200
|
| 195 |
+
# total per layer: 47185920 + 176947200 = 224133120
|
| 196 |
+
# all layers: 2013757440 + 48*224133120 = 12772147200
|
| 197 |
+
#
|
| 198 |
+
# For batch [512, 1024, 2048], tokens_sum = 3584:
|
| 199 |
+
# dense flops: 6 * 12772147200 * 3584 = 274652253388800
|
| 200 |
+
# seqlen_square_sum: 180355072 (calculated with sliding window logic)
|
| 201 |
+
# attn flops: 6 * 180355072 * 256 * 16 = 8864812498944
|
| 202 |
+
# total: 283517065887744 / 1e12 = 283.517065887744
|
| 203 |
+
#
|
| 204 |
+
# For batch [4096, 4096, 4096], tokens_sum = 12288:
|
| 205 |
+
# dense flops: 6 * 12772147200 * 12288 = 941664868761600
|
| 206 |
+
# seqlen_square_sum: 905969664 (calculated with sliding window logic)
|
| 207 |
+
# attn flops: 6 * 905969664 * 256 * 16 = 44530220924928
|
| 208 |
+
# total: 986195089686528 / 1e12 = 986.195089686528
|
| 209 |
+
"expected_flops_tuple": (279084659638272 / 1e12, 963929979224064 / 1e12),
|
| 210 |
+
},
|
| 211 |
+
"gpt_oss": {
|
| 212 |
+
"config": {
|
| 213 |
+
"model_type": "gpt_oss",
|
| 214 |
+
"vocab_size": 201088,
|
| 215 |
+
"hidden_size": 2880,
|
| 216 |
+
"num_hidden_layers": 24,
|
| 217 |
+
"num_attention_heads": 64,
|
| 218 |
+
"num_key_value_heads": 8,
|
| 219 |
+
"head_dim": 64,
|
| 220 |
+
"intermediate_size": 2880,
|
| 221 |
+
"num_local_experts": 32,
|
| 222 |
+
"num_experts_per_tok": 4,
|
| 223 |
+
"sliding_window": 128,
|
| 224 |
+
"layer_types": [
|
| 225 |
+
"sliding_attention",
|
| 226 |
+
"full_attention",
|
| 227 |
+
"sliding_attention",
|
| 228 |
+
"full_attention",
|
| 229 |
+
"sliding_attention",
|
| 230 |
+
"full_attention",
|
| 231 |
+
"sliding_attention",
|
| 232 |
+
"full_attention",
|
| 233 |
+
"sliding_attention",
|
| 234 |
+
"full_attention",
|
| 235 |
+
"sliding_attention",
|
| 236 |
+
"full_attention",
|
| 237 |
+
"sliding_attention",
|
| 238 |
+
"full_attention",
|
| 239 |
+
"sliding_attention",
|
| 240 |
+
"full_attention",
|
| 241 |
+
"sliding_attention",
|
| 242 |
+
"full_attention",
|
| 243 |
+
"sliding_attention",
|
| 244 |
+
"full_attention",
|
| 245 |
+
"sliding_attention",
|
| 246 |
+
"full_attention",
|
| 247 |
+
"sliding_attention",
|
| 248 |
+
"full_attention",
|
| 249 |
+
],
|
| 250 |
+
},
|
| 251 |
+
"batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]),
|
| 252 |
+
# GPT-OSS has alternating sliding / full attention
|
| 253 |
+
# Even layers (12 layers) use sliding window attention with window_size = 128
|
| 254 |
+
# Odd layers (12 layers) use full attention
|
| 255 |
+
#
|
| 256 |
+
# Non-attention FLOPs:
|
| 257 |
+
# vocab part: 201088 * 2880 * 2 = 1158266880
|
| 258 |
+
# attn linear part per layer:
|
| 259 |
+
# Q: 2880 * (64 * 64) = 11796480
|
| 260 |
+
# K: 2880 * (8 * 64) = 1474560
|
| 261 |
+
# V: 2880 * (8 * 64) = 1474560
|
| 262 |
+
# O: (64 * 64) * 2880 = 11796480
|
| 263 |
+
# attn linear total = 26542080
|
| 264 |
+
# mlp (MoE, SwiGLU) part per layer:
|
| 265 |
+
# gate: 2880 * 32 = 92160
|
| 266 |
+
# active experts: 3 * 2880 * 2880 * 4 = 99532800
|
| 267 |
+
# mlp total = 99624960
|
| 268 |
+
# total per layer: 26542080 + 99624960 = 126167040
|
| 269 |
+
# all layers:
|
| 270 |
+
# 126167040 * 24 = 3028008960
|
| 271 |
+
# total dense params:
|
| 272 |
+
# 3028008960 + 1158266880 = 4186275840
|
| 273 |
+
#
|
| 274 |
+
# For batch [512, 1024, 2048], tokens_sum = 3584:
|
| 275 |
+
# dense flops: 6 * 4186275840 * 3584 = 90021675663360
|
| 276 |
+
# seqlen_square_sum: 71565312 (calculated with sliding window logic)
|
| 277 |
+
# attn flops: 6 * 71565312 * 64 * 64 = 3517578215424
|
| 278 |
+
# total: 93539253878784 / 1e12 = 93.539253878784
|
| 279 |
+
#
|
| 280 |
+
# For batch [4096, 4096, 4096], tokens_sum = 12288:
|
| 281 |
+
# dense flops: 6 * 4186275840 * 12288 = 308646629068800
|
| 282 |
+
# seqlen_square_sum: 622854144 (calculated with sliding window logic)
|
| 283 |
+
# attn flops: 6 * 622854144 * 64 * 64 = 30613642948608
|
| 284 |
+
# total: 339260272017408 / 1e12 = 339.260272017408
|
| 285 |
+
"expected_flops_tuple": (91780464771072 / 1e12, 323953008574464 / 1e12),
|
| 286 |
+
},
|
| 287 |
+
"apertus": {
|
| 288 |
+
"config": { # swiss-ai/Apertus-8B
|
| 289 |
+
"model_type": "apertus",
|
| 290 |
+
"vocab_size": 131072,
|
| 291 |
+
"hidden_size": 4096,
|
| 292 |
+
"intermediate_size": 21504,
|
| 293 |
+
"num_hidden_layers": 32,
|
| 294 |
+
"num_attention_heads": 32,
|
| 295 |
+
"num_key_value_heads": 32,
|
| 296 |
+
"hidden_act": "xielu",
|
| 297 |
+
# head_dim will be derived as 4096 / 32 = 128
|
| 298 |
+
},
|
| 299 |
+
"batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]),
|
| 300 |
+
# Calculation for Apertus (hidden_act="xielu" -> MLP uses [k_mlp=2]*H*I params; qk_norm=True -> [k_qkn=2]*H):
|
| 301 |
+
# V=131072, H=4096, I=21504, L=32, k_mlp=2 (XIELU), k_qkn=2 (QK norm), S=6
|
| 302 |
+
# S*(2*V*H + L*(4*H**2 + k_mlp*H*I + k_qkn*H)) * (SUM[seqlen]) + 6*SUM[seqlen**2]*L*H
|
| 303 |
+
"expected_flops_tuple": (194825353691136 / 1e12, 692711652851712 / 1e12),
|
| 304 |
+
},
|
| 305 |
+
"qwen3_vl": {
|
| 306 |
+
"config": { # Qwen/Qwen3-VL-8B
|
| 307 |
+
"model_type": "qwen3_vl",
|
| 308 |
+
# -------- Text config --------
|
| 309 |
+
"text_config": {
|
| 310 |
+
"vocab_size": 151936,
|
| 311 |
+
"hidden_size": 4096,
|
| 312 |
+
"intermediate_size": 12288,
|
| 313 |
+
"num_hidden_layers": 36,
|
| 314 |
+
"num_attention_heads": 32,
|
| 315 |
+
"num_key_value_heads": 8,
|
| 316 |
+
"head_dim": 128,
|
| 317 |
+
},
|
| 318 |
+
# -------- Vision config (ViT) --------
|
| 319 |
+
"vision_config": {
|
| 320 |
+
"deepstack_visual_indexes": [8, 16, 24],
|
| 321 |
+
"num_heads": 16,
|
| 322 |
+
"depth": 27,
|
| 323 |
+
"hidden_size": 1152,
|
| 324 |
+
"intermediate_size": 4304,
|
| 325 |
+
"out_hidden_size": 4096,
|
| 326 |
+
"spatial_merge_size": 2,
|
| 327 |
+
"temporal_patch_size": 2,
|
| 328 |
+
"in_channels": 3,
|
| 329 |
+
"patch_size": 16,
|
| 330 |
+
},
|
| 331 |
+
},
|
| 332 |
+
"batch_seqlens_tuple": (
|
| 333 |
+
[512, 1024, 2048],
|
| 334 |
+
[4096, 4096, 4096],
|
| 335 |
+
),
|
| 336 |
+
"images_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]),
|
| 337 |
+
# -----Text-----
|
| 338 |
+
# 6*(vocab*hidden*2
|
| 339 |
+
# + layer*(hidden*(q+k+v+o) + hidden*inter*3)
|
| 340 |
+
# )*token_sum
|
| 341 |
+
# + 6*sum(seqlen^2)*layer*hidden
|
| 342 |
+
#
|
| 343 |
+
# -----ViT-----
|
| 344 |
+
# patch_embed_N =hidden*temporal_patch_size*in_channels* patch_size^2
|
| 345 |
+
# attn_linear_N =hidden*(4*hidden)
|
| 346 |
+
# mlp_N =hidden*inter*2
|
| 347 |
+
# merger_N =((o+hidden*spatial_merge_size^2) * (hidden*spatial_merge_size^2))
|
| 348 |
+
# deepstack_merger_N =merger_N * 3
|
| 349 |
+
# dense_N =patch_embed_N + (attn_linear_N + mlp_N) * 27 + deepstack_merger_N + merger_N
|
| 350 |
+
#
|
| 351 |
+
# 6*(151936*4096*2
|
| 352 |
+
# + 36*(4096*(4096+1024+1024+4096) + 4096*12288*3)
|
| 353 |
+
# )*(512+1024+2048)
|
| 354 |
+
# + 12*(512*512+1024*1024+2048*2048)*36*4096
|
| 355 |
+
# + 6 * dense_N * (512 + 1024 + 2048)
|
| 356 |
+
# + 12 * (512**2 + 1024**2 + 2048**2) * 27 * 16 * 72
|
| 357 |
+
#
|
| 358 |
+
# 6*(151936*4096*2
|
| 359 |
+
# + 36*(4096*(4096+1024+1024+4096) + 4096*12288*3)
|
| 360 |
+
# )*(4096+4096+4096)
|
| 361 |
+
# + 12*(4096*4096+4096*4096+4096*4096)*36*4096
|
| 362 |
+
# + 6 * dense_N * (4096 + 4096 + 2048)
|
| 363 |
+
# + 12 * (4096**2 + 4096**2 + 4096**2) * 27 * 16 * 72
|
| 364 |
+
"expected_flops_tuple": (
|
| 365 |
+
195379819708416 / 1e12,
|
| 366 |
+
709446422495232 / 1e12,
|
| 367 |
+
),
|
| 368 |
+
},
|
| 369 |
+
"qwen3_vl_moe": {
|
| 370 |
+
"config": { # Qwen/Qwen3-VL-30B-A3B
|
| 371 |
+
"model_type": "qwen3_vl_moe",
|
| 372 |
+
# -------- Text config --------
|
| 373 |
+
"text_config": {
|
| 374 |
+
"vocab_size": 151936,
|
| 375 |
+
"hidden_size": 2048,
|
| 376 |
+
"num_hidden_layers": 48,
|
| 377 |
+
"num_attention_heads": 32,
|
| 378 |
+
"num_key_value_heads": 4,
|
| 379 |
+
"head_dim": 128,
|
| 380 |
+
"moe_intermediate_size": 768,
|
| 381 |
+
"num_experts": 128,
|
| 382 |
+
"num_experts_per_tok": 8,
|
| 383 |
+
},
|
| 384 |
+
# -------- Vision config (ViT) --------
|
| 385 |
+
"vision_config": {
|
| 386 |
+
"deepstack_visual_indexes": [8, 16, 24],
|
| 387 |
+
"num_heads": 16,
|
| 388 |
+
"depth": 27,
|
| 389 |
+
"hidden_size": 1152,
|
| 390 |
+
"intermediate_size": 4304,
|
| 391 |
+
"out_hidden_size": 4096,
|
| 392 |
+
"spatial_merge_size": 2,
|
| 393 |
+
"temporal_patch_size": 2,
|
| 394 |
+
"in_channels": 3,
|
| 395 |
+
"patch_size": 16,
|
| 396 |
+
},
|
| 397 |
+
},
|
| 398 |
+
"batch_seqlens_tuple": (
|
| 399 |
+
[512, 1024, 2048],
|
| 400 |
+
[4096, 4096, 4096],
|
| 401 |
+
),
|
| 402 |
+
"images_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]),
|
| 403 |
+
# -----Text-----
|
| 404 |
+
# 6*(vocab*hidden*2
|
| 405 |
+
# + layer*(hidden*(q+k+v+head*head_dim)+hidden*inter*top_k_exp*3+hidden*num_experts)
|
| 406 |
+
# )*token_sum
|
| 407 |
+
# + 6*sum(seqlen^2)*layer*hidden
|
| 408 |
+
#
|
| 409 |
+
# -----ViT-----
|
| 410 |
+
# patch_embed_N =hidden*temporal_patch_size*in_channels* patch_size^2
|
| 411 |
+
# attn_linear_N =hidden*(4*hidden)
|
| 412 |
+
# mlp_N =hidden*inter*2
|
| 413 |
+
# merger_N =((o+hidden*spatial_merge_size^2) * (hidden*spatial_merge_size^2))
|
| 414 |
+
# deepstack_merger_N =merger_N * 3
|
| 415 |
+
# dense_N =patch_embed_N + (attn_linear_N + mlp_N) * 27 + deepstack_merger_N + merger_N
|
| 416 |
+
#
|
| 417 |
+
# 6*(151936*2048*2
|
| 418 |
+
# + 48*(2048*(128*32+128*4*2+128*32)+2048*768*8*3+2048*128)
|
| 419 |
+
# )*(512+1024+2048)
|
| 420 |
+
# + 12*(512*512+1024*1024+2048*2048)*48*4096
|
| 421 |
+
# + 6 * dense_N * (512 + 1024 + 2048)
|
| 422 |
+
# + 12 * (512**2 + 1024**2 + 2048**2) * 27 * 16 * 72
|
| 423 |
+
#
|
| 424 |
+
# 6*(151936*2048*2
|
| 425 |
+
# 48*(2048*(128*32+128*4*2+128*32)+2048*768*8*3+2048*128)
|
| 426 |
+
# )*(4096+4096+4096)
|
| 427 |
+
# + 12*(4096*4096+4096*4096+4096*4096)*48*4096
|
| 428 |
+
# + 6 * dense_N * (4096 + 4096 + 2048)
|
| 429 |
+
# + 12 * (4096**2 + 4096**2 + 4096**2) * 27 * 16 * 72
|
| 430 |
+
"expected_flops_tuple": (
|
| 431 |
+
92975451340800 / 1e12,
|
| 432 |
+
367622860308480 / 1e12,
|
| 433 |
+
),
|
| 434 |
+
},
|
| 435 |
+
}
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
@pytest.mark.parametrize(
    "config_type",
    [
        "llama",
        "qwen2",
        "qwen3",
        "qwen3_moe",
        "deepseek_v3",
        "mistral",
        "gemma3_text",
        "apertus",
        "gpt_oss",
        "qwen3_vl",
        "qwen3_vl_moe",
    ],
)
def test_flops_counter(config_type: str):
    """Check FlopsCounter.estimate_flops against hand-derived expected TFLOPs.

    VLM entries in CONFIG additionally carry ``images_seqlens_tuple``; for those,
    the per-image sequence lengths are forwarded to ``estimate_flops``.
    """
    case = CONFIG[config_type]
    counter = FlopsCounter(Config(case["config"]))

    batch_tuple = case["batch_seqlens_tuple"]
    expected_tuple = case["expected_flops_tuple"]
    # Text-only configs have no image seqlens; pad with None so one loop serves both.
    images_tuple = case.get("images_seqlens_tuple", (None,) * len(batch_tuple))

    for batch_seqlens, images_seqlens, expected_flops in zip(
        batch_tuple, images_tuple, expected_tuple, strict=True
    ):
        # Only pass the kwarg when the case defines it, matching the two
        # original call shapes exactly.
        extra = {} if images_seqlens is None else {"images_seqlens": images_seqlens}
        # delta_time=1 makes the returned rate equal the raw flops count.
        counted_flops, _ = counter.estimate_flops(batch_seqlens, 1, **extra)
        print(f"Expect flops for {case['config']} is {expected_flops}, but get {counted_flops}")
        assert math.isclose(counted_flops, expected_flops), (
            f"Expect flops for {case['config']} is {expected_flops}, but get {counted_flops}"
        )
|
code/RL_model/verl/verl_train/tests/utils/test_fs_on_cpu.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
|
| 18 |
+
import verl.utils.fs as fs
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def test_record_and_check_directory_structure(tmp_path):
    """A recorded tree layout must validate while unchanged and fail after a change."""
    # Build a small tree: one file at the root, one in a subdirectory.
    root = tmp_path / "test_dir"
    root.mkdir()
    (root / "file1.txt").write_text("test")
    (root / "subdir").mkdir()
    (root / "subdir" / "file2.txt").write_text("test")

    # Snapshot the directory layout to a record file.
    record_file = fs._record_directory_structure(root)
    assert os.path.exists(record_file)

    # Unmodified tree: the structure check passes.
    assert fs._check_directory_structure(root, record_file) is True

    # Adding a new file invalidates the recorded structure.
    (root / "new_file.txt").write_text("test")
    assert fs._check_directory_structure(root, record_file) is False
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def test_copy_from_hdfs_with_mocks(tmp_path, monkeypatch):
    """copy_to_local must place a remote file under cache_dir/md5(path)/basename."""
    # Treat every path as remote so the HDFS branch is exercised.
    monkeypatch.setattr(fs, "is_non_local", lambda path: True)

    def fake_copy(src: str, dst: str, *args, **kwargs):
        # Stand-in for the real HDFS copy: materialize an empty destination file.
        target = Path(dst)
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_bytes(b"")

    monkeypatch.setattr(fs, "copy", fake_copy)  # Mock actual HDFS copy

    cache_dir = tmp_path / "cache"
    hdfs_path = "hdfs://test/path/file.txt"

    # First copy: the file lands at the md5-bucketed cache location.
    local_path = fs.copy_to_local(hdfs_path, cache_dir=cache_dir)
    expected_path = os.path.join(cache_dir, fs.md5_encode(hdfs_path), os.path.basename(hdfs_path))
    assert local_path == expected_path
    assert os.path.exists(local_path)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def test_always_recopy_flag(tmp_path, monkeypatch):
    """always_recopy=True must re-run the copy; the default must hit the cache."""
    monkeypatch.setattr(fs, "is_non_local", lambda path: True)

    call_count = 0

    def fake_copy(src: str, dst: str, *args, **kwargs):
        # Count invocations and fake a successful HDFS copy.
        nonlocal call_count
        call_count += 1
        target = Path(dst)
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_bytes(b"")

    monkeypatch.setattr(fs, "copy", fake_copy)  # Mock actual HDFS copy

    cache_dir = tmp_path / "cache"
    hdfs_path = "hdfs://test/path/file.txt"

    # Cold cache: the first call performs one copy.
    fs.copy_to_local(hdfs_path, cache_dir=cache_dir)
    assert call_count == 1

    # Forced recopy bypasses the cache.
    fs.copy_to_local(hdfs_path, cache_dir=cache_dir, always_recopy=True)
    assert call_count == 2

    # Normal call afterwards is served from cache — no extra copy.
    fs.copy_to_local(hdfs_path, cache_dir=cache_dir)
    assert call_count == 2
|
code/RL_model/verl/verl_train/tests/utils/test_groupwise.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
# Copyright 2023-2024 SGLang Team
|
| 3 |
+
# Copyright 2025 ModelBest Inc. and/or its affiliates
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
import os
|
| 17 |
+
|
| 18 |
+
os.environ.setdefault("VERL_FORCE_DEVICE", "cpu") # ensure CPU for tests
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
import pytest
|
| 22 |
+
import torch
|
| 23 |
+
|
| 24 |
+
from verl.utils import as_torch_index, group_mean_std
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def test_as_torch_index_basic_integers():
    """Integer labels factorize to a contiguous long index tensor on CPU."""
    idx = as_torch_index([2, 2, 5, 7, 5, 2])
    assert idx.dtype == torch.long
    assert idx.device.type == "cpu"
    # Equal input labels receive equal group ids.
    values = idx.tolist()
    assert values[0] == values[1]
    # Three distinct labels {2, 5, 7} -> three groups.
    assert len(torch.unique(idx)) == 3
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def test_as_torch_index_near_integer_floats():
    """Floats within rounding tolerance of integers collapse to the same group."""
    labels = np.array([1.0000001, 2.0, 1.0, 3.0000000001], dtype=np.float64)
    # Values should be rounded to integers before factorization.
    idx = as_torch_index(labels)
    assert idx.dtype == torch.long
    # Rounded labels are {1, 2, 3} -> three groups.
    assert len(torch.unique(idx)) == 3
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def test_as_torch_index_factorization_mixed():
    """Factorization must not coerce strings to numbers."""
    idx = as_torch_index(["a", "b", "a", "c", "0042", 42])
    groups = idx.tolist()
    # The string "0042" and the int 42 remain distinct groups.
    assert groups[4] != groups[5]
    # Labels {"a", "b", "c", "0042", 42} -> five groups.
    assert len(torch.unique(idx)) == 5
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def test_group_mean_std_simple():
    """Per-group stats for groups 0 -> [1, 3] and 1 -> [2] (a singleton)."""
    scores = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
    group_index = as_torch_index([0, 1, 0])

    mean_g, std_g, cnt_g = group_mean_std(scores, group_index)

    # group 0: mean = (1 + 3) / 2 = 2
    # unbiased sample std = sqrt(((1 - 2)^2 + (3 - 2)^2) / (2 - 1)) = sqrt(2)
    assert torch.allclose(mean_g, torch.tensor([2.0, 0.0]))
    assert torch.allclose(cnt_g, torch.tensor([2.0, 1.0]))
    # Singleton group is normalized by group_mean_std to mean 0 / std 1.
    assert mean_g[1].item() == 0.0
    assert std_g[1].item() == 1.0
    assert pytest.approx(std_g[0].item(), rel=1e-6) == (2.0**0.5)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def test_group_mean_std_empty():
    """Empty scores/index yield empty per-group statistics (no crash)."""
    scores = torch.empty(0, dtype=torch.float32)
    group_index = torch.empty(0, dtype=torch.long)
    mean_g, std_g, cnt_g = group_mean_std(scores, group_index)
    # All three outputs are empty tensors.
    assert mean_g.numel() == 0
    assert std_g.numel() == 0
    assert cnt_g.numel() == 0
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def test_group_mean_std_default_device_no_force_env(monkeypatch):
    """
    Regression test:
    - group_mean_std(device=None) must not pass a device *module* (e.g., torch.cuda)
      into Tensor.to(device=...), which crashes with:
      TypeError: to() received an invalid combination of arguments - got (..., device=module, ...)
    """
    # Simulate a non-pytest environment (training code path) while keeping the test CPU-only.
    monkeypatch.delenv("VERL_FORCE_DEVICE", raising=False)
    monkeypatch.delenv("PYTEST_CURRENT_TEST", raising=False)

    # Pin device selection to CPU even on machines with CUDA/NPU hardware.
    import verl.utils.device as device_mod

    monkeypatch.setattr(device_mod, "is_cuda_available", False)
    monkeypatch.setattr(device_mod, "is_npu_available", False)

    scores = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
    group_index = torch.tensor([0, 1, 0], dtype=torch.long)

    outputs = group_mean_std(scores, group_index)
    # Every returned tensor must land on CPU.
    for tensor in outputs:
        assert tensor.device.type == "cpu"
|
code/RL_model/verl/verl_train/tests/utils/test_import_utils_on_cpu.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
|
| 17 |
+
import pytest
|
| 18 |
+
|
| 19 |
+
from verl.utils.import_utils import load_extern_object
|
| 20 |
+
|
| 21 |
+
# Path to the test module
|
| 22 |
+
TEST_MODULE_PATH = os.path.join(os.path.dirname(__file__), "_test_module.py")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def test_load_extern_object_class():
    """A class loaded from an external file must be usable like a normal class."""
    loaded_cls = load_extern_object(TEST_MODULE_PATH, "TestClass")

    # The loader must hand back the class object itself, under its own name.
    assert loaded_cls is not None
    assert loaded_cls.__name__ == "TestClass"

    # Default construction should expose the default value.
    default_obj = loaded_cls()
    assert default_obj.value == "default"

    # Construction with an explicit value should round-trip through the getter.
    custom_obj = loaded_cls("custom")
    assert custom_obj.get_value() == "custom"
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def test_load_extern_object_function():
    """A plain function loaded from an external file must be callable."""
    loaded_fn = load_extern_object(TEST_MODULE_PATH, "test_function")

    # The loader must return a real callable object.
    assert loaded_fn is not None
    assert callable(loaded_fn)

    # Invoking it should produce the module's canonical result string.
    outcome = loaded_fn()
    assert outcome == "test_function_result"
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def test_load_extern_object_constant():
    """A module-level constant loaded from an external file keeps its value."""
    loaded_value = load_extern_object(TEST_MODULE_PATH, "TEST_CONSTANT")

    # Both presence and the exact value of the constant are verified.
    assert loaded_value is not None
    assert loaded_value == "test_constant_value"
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def test_load_extern_object_nonexistent_file():
    """A missing source file must surface as FileNotFoundError."""
    missing_path = "/nonexistent/path.py"
    with pytest.raises(FileNotFoundError):
        load_extern_object(missing_path, "SomeType")
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def test_load_extern_object_nonexistent_type():
    """Requesting a name the module does not define must raise AttributeError."""
    missing_name = "NonExistentType"
    with pytest.raises(AttributeError):
        load_extern_object(TEST_MODULE_PATH, missing_name)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def test_load_extern_object_none_path():
    """Passing None instead of a file path must raise AttributeError."""
    no_path = None
    with pytest.raises(AttributeError):
        load_extern_object(no_path, "SomeType")
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def test_load_extern_object_invalid_module():
    """A syntactically broken module must surface as RuntimeError."""
    import tempfile

    # Write deliberately invalid Python into a throwaway file on disk.
    broken_source = "This is not valid Python syntax :"
    with tempfile.NamedTemporaryFile(suffix=".py", mode="w+", delete=False) as handle:
        handle.write(broken_source)
        broken_path = handle.name

    try:
        with pytest.raises(RuntimeError):
            load_extern_object(broken_path, "SomeType")
    finally:
        # Always remove the temporary file, even if the assertion fails.
        if os.path.exists(broken_path):
            os.remove(broken_path)
|
code/RL_model/verl/verl_train/tests/utils/test_linear_cross_entropy.py
ADDED
|
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
#
|
| 17 |
+
|
| 18 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 19 |
+
#
|
| 20 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 21 |
+
# you may not use this file except in compliance with the License.
|
| 22 |
+
# You may obtain a copy of the License at
|
| 23 |
+
#
|
| 24 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 25 |
+
#
|
| 26 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 27 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 28 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 29 |
+
# See the License for the specific language governing permissions and
|
| 30 |
+
# limitations under the License.
|
| 31 |
+
|
| 32 |
+
import os
|
| 33 |
+
|
| 34 |
+
import torch
|
| 35 |
+
|
| 36 |
+
import verl.utils.torch_functional as verl_F
|
| 37 |
+
from verl.utils.experimental.torch_functional import FusedLinearForPPO
|
| 38 |
+
from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy
|
| 39 |
+
from verl.utils.torch_functional import logprobs_from_logits
|
| 40 |
+
|
| 41 |
+
# Module-level singletons, compiled once at import time so every benchmark
# iteration below reuses the same compiled graphs instead of re-tracing.
compute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True)
fused_linear_for_ppo = FusedLinearForPPO()
fused_linear_for_ppo.compile(dynamic=True)
|
| 44 |
+
|
| 45 |
+
# Number of hyper-parameter test cases to run (capped at 5 by generate_hyper).
# Bug fix: os.environ.get returns a *string* whenever the variable is set in the
# environment, which would break `range(MAX_TEST_CASES)` and the `<= 5`
# comparison below — cast to int so both the default and the env value work.
MAX_TEST_CASES = int(os.environ.get("MAX_TEST_CASES", 5))
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def run_torch_entropy(
    hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, temperature: float, reduction="none"
) -> list[torch.Tensor]:
    """Pure-torch reference: temperature-scaled logits -> (token log-probs, token entropy).

    `hidden` is expected with a leading batch dim of 1; it is squeezed away so the
    whole computation runs over [num_tokens, ...] in float32.
    """
    flat_hidden = hidden.squeeze(0).to(torch.float32)
    proj_weight = weight.transpose(0, 1).to(torch.float32)
    logits = torch.matmul(flat_hidden, proj_weight) / temperature  # [num_tokens, vocab_size]

    # entropy(z) = logsumexp(z) - sum(softmax(z) * z), evaluated per token.
    probs = torch.nn.functional.softmax(logits, dim=-1)
    entropy = torch.logsumexp(logits, dim=-1) - torch.sum(probs * logits, dim=-1)

    # cross_entropy yields the NLL; negate it to obtain log-probabilities.
    nll = torch.nn.functional.cross_entropy(logits, labels.squeeze(0), reduction=reduction)
    return torch.neg(nll), entropy
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def run_verl_original_entropy(
    hidden: torch.Tensor,
    weight: torch.Tensor,
    labels: torch.Tensor,
    temperature: float,
) -> list[torch.Tensor]:
    """VeRL baseline: compiled entropy kernel plus logprobs_from_logits over scaled logits."""
    flat_hidden = hidden.squeeze(0).to(torch.float32)
    proj_weight = weight.transpose(0, 1).to(torch.float32)
    logits = torch.matmul(flat_hidden, proj_weight) / temperature  # [num_tokens, vocab_size]

    # Entropy via the torch.compile'd verl kernel: ((total_nnz / sp) + pad)
    entropy = compute_entropy_from_logits(logits)
    # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen)
    logprobs = logprobs_from_logits(logits=logits, labels=labels, inplace_backward=False)
    return logprobs, entropy
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
# To be tested
|
| 82 |
+
def run_verl_torch_fused_entropy(
    hidden: torch.Tensor,
    weight: torch.Tensor,
    labels: torch.Tensor,
    temperature: float,
):
    """Fused VeRL path: FusedLinearForPPO computes logprobs and entropy in one call."""
    hidden_f32 = hidden.to(torch.float32)
    weight_f32 = weight.to(torch.float32)
    logprobs, entropy = fused_linear_for_ppo(
        hidden_f32,
        weight_f32,
        labels,
        temperature=temperature,
    )
    # Drop the leading batch dim so outputs align with the other implementations.
    return logprobs.squeeze(0), entropy.squeeze(0)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class TestLinearCrossEntropy:
    """Benchmark/correctness harness comparing four (logprobs, entropy) implementations.

    Runs a pure-torch reference, the VeRL baseline, the torch-compiled fused path,
    and the custom CUDA kernel on identical random inputs, asserting numerical
    agreement of forward outputs and gradients and reporting latency and peak
    memory. Requires a CUDA device; this is a script-style harness, not a pytest
    test class (note the non-pytest __init__).
    """

    def __init__(self, test_case_idx: int, temperature: float = 1.5) -> None:
        # Which hyper-parameter preset (see generate_hyper) and softmax temperature.
        self.test_case_idx = test_case_idx
        self.temperature = temperature

    def cleanup(self):
        """Release cached CUDA memory and reset peak-memory counters between runs."""
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        import gc

        gc.collect()
        torch.cuda.synchronize()

    def generate_hyper(self):
        """Select (batch_size, num_tokens, hidden_size, vocab_size) for this test case."""
        global MAX_TEST_CASES

        self.dtype = torch.bfloat16
        if self.test_case_idx == 0:
            self.batch_size = 1
            self.num_tokens = 1937
            self.hidden_size = 3584
            self.vocab_size = 152064
        elif self.test_case_idx == 1:
            self.batch_size = 1
            self.num_tokens = 2169
            self.hidden_size = 896
            self.vocab_size = 151936
        elif self.test_case_idx == 2:
            self.batch_size = 1
            self.num_tokens = 1530
            self.hidden_size = 2048
            self.vocab_size = 32256
        elif self.test_case_idx == 3:
            self.batch_size = 1
            self.num_tokens = 1388
            self.hidden_size = 4096
            self.vocab_size = 102400
        elif self.test_case_idx == 4:
            self.batch_size = 1
            self.num_tokens = 8192
            self.hidden_size = 4096
            self.vocab_size = 102400
        else:
            raise ValueError(f"Invalid test case index: {self.test_case_idx}")
        assert MAX_TEST_CASES <= 5, "MAX_TEST_CASES should be less than or equal to 5."

    def generate_forward_inputs(self):
        """Create fresh random (hidden, weight, labels) on CUDA; hidden/weight require grad."""
        hidden = (
            torch.empty((self.batch_size, self.num_tokens, self.hidden_size), dtype=self.dtype, device="cuda")
            .uniform_(-0.5, 0.5)
            .requires_grad_()
        )
        weight = (
            torch.empty((self.vocab_size, self.hidden_size), dtype=self.dtype, device="cuda")
            .uniform_(-0.5, 0.5)
            .requires_grad_()
        )
        labels = torch.randint(0, self.vocab_size, (self.batch_size, self.num_tokens), device="cuda")
        return hidden, weight, labels

    def generate_backward_inputs(self):
        """Create random upstream gradients for (entropy, logprobs), one value per token."""
        g_entropy = torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-0.5, 0.5)
        g_logprobs = torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-1, 1)
        return g_entropy, g_logprobs

    def verify_correctness(self, iterations=5):
        """Cross-check forward outputs and gradients of all four implementations.

        Each iteration draws fresh inputs, times forward and backward of every
        implementation with CUDA events, and asserts numerical agreement (tight
        tolerances among torch-level paths, looser ones against the hand-written
        kernel). The first iteration's timings are discarded as warm-up before
        per-implementation averages are printed.
        """
        self.cleanup()
        self.generate_hyper()

        torch_forward_latency = list()
        torch_backward_latency = list()
        verl_forward_latency = list()
        verl_backward_latency = list()
        verl_fused_forward_latency = list()
        verl_fused_backward_latency = list()
        kernel_forward_latency = list()
        kernel_backward_latency = list()

        # CUDA events give device-side timing; each record/synchronize pair below
        # brackets exactly one implementation's forward or backward pass.
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

        for i in range(iterations):
            print(f"[INFO]: Iteration {i + 1} / {iterations}...", end="\r")
            hidden, weight, labels = self.generate_forward_inputs()

            start_event.record()
            (torch_logprobs, torch_entropy) = run_torch_entropy(hidden, weight, labels, self.temperature)
            end_event.record()
            torch.cuda.synchronize()
            torch_forward_latency.append(start_event.elapsed_time(end_event))

            start_event.record()
            (verl_logprobs, verl_entropy) = run_verl_original_entropy(hidden, weight, labels, self.temperature)
            end_event.record()
            torch.cuda.synchronize()
            verl_forward_latency.append(start_event.elapsed_time(end_event))

            start_event.record()
            (verl_fused_logprobs, verl_fused_entropy) = run_verl_torch_fused_entropy(
                hidden, weight, labels, self.temperature
            )
            end_event.record()
            torch.cuda.synchronize()
            verl_fused_forward_latency.append(start_event.elapsed_time(end_event))

            start_event.record()
            (kernel_logprobs, kernel_entropy) = linear_cross_entropy(hidden, weight, labels, self.temperature)
            end_event.record()
            torch.cuda.synchronize()
            kernel_forward_latency.append(start_event.elapsed_time(end_event))

            # Forward agreement: tight tolerances among the torch-level paths.
            torch.testing.assert_close(torch_logprobs, verl_logprobs, atol=1e-4, rtol=1e-4)
            torch.testing.assert_close(torch_entropy, verl_entropy, atol=1e-4, rtol=1e-4)

            torch.testing.assert_close(torch_logprobs, verl_fused_logprobs, atol=1e-4, rtol=1e-4)
            torch.testing.assert_close(torch_entropy, verl_fused_entropy, atol=1e-4, rtol=1e-4)
            torch.testing.assert_close(verl_logprobs, verl_fused_logprobs, atol=1e-4, rtol=1e-4)
            torch.testing.assert_close(verl_entropy, verl_fused_entropy, atol=1e-4, rtol=1e-4)

            # Looser tolerances against the custom kernel.
            torch.testing.assert_close(torch_logprobs, kernel_logprobs, atol=1e-3, rtol=2e-4)
            torch.testing.assert_close(torch_entropy, kernel_entropy, atol=5e-3, rtol=5e-4)
            torch.testing.assert_close(verl_logprobs, kernel_logprobs, atol=1e-3, rtol=2e-4)
            torch.testing.assert_close(verl_entropy, kernel_entropy, atol=5e-3, rtol=5e-4)
            torch.testing.assert_close(verl_fused_logprobs, kernel_logprobs, atol=1e-3, rtol=2e-4)
            torch.testing.assert_close(verl_fused_entropy, kernel_entropy, atol=5e-3, rtol=5e-4)

            # backward
            g_entropy, g_logprobs = self.generate_backward_inputs()

            start_event.record()
            (d_torch_hidden, d_torch_weight) = torch.autograd.grad(
                (torch_entropy, torch_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False
            )
            end_event.record()
            torch.cuda.synchronize()
            torch_backward_latency.append(start_event.elapsed_time(end_event))

            start_event.record()
            (d_verl_hidden, d_verl_weight) = torch.autograd.grad(
                (verl_entropy, verl_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False
            )
            end_event.record()
            torch.cuda.synchronize()
            verl_backward_latency.append(start_event.elapsed_time(end_event))

            start_event.record()
            (d_verl_fused_hidden, d_verl_fused_weight) = torch.autograd.grad(
                (verl_fused_entropy, verl_fused_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False
            )
            end_event.record()
            torch.cuda.synchronize()
            verl_fused_backward_latency.append(start_event.elapsed_time(end_event))

            start_event.record()
            (d_kernel_hidden, d_kernel_weight) = torch.autograd.grad(
                (kernel_entropy, kernel_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False
            )
            end_event.record()
            torch.cuda.synchronize()
            kernel_backward_latency.append(start_event.elapsed_time(end_event))

            # Gradient agreement among the torch-level paths.
            torch.testing.assert_close(d_torch_hidden, d_verl_hidden, atol=1e-2, rtol=1e-4)
            torch.testing.assert_close(d_torch_weight, d_verl_weight, atol=1e-2, rtol=1e-4)

            torch.testing.assert_close(d_torch_hidden, d_verl_fused_hidden, atol=1e-2, rtol=1e-4)
            torch.testing.assert_close(d_torch_weight, d_verl_fused_weight, atol=1e-2, rtol=1e-4)
            torch.testing.assert_close(d_verl_hidden, d_verl_fused_hidden, atol=1e-2, rtol=1e-4)
            torch.testing.assert_close(d_verl_weight, d_verl_fused_weight, atol=1e-2, rtol=1e-4)
            # NOTE(review): the next two checks duplicate the first pair above.
            torch.testing.assert_close(d_torch_hidden, d_verl_hidden, atol=1e-2, rtol=1e-4)
            torch.testing.assert_close(d_torch_weight, d_verl_weight, atol=1e-2, rtol=1e-4)

            # Gradient agreement against the custom kernel (loosest tolerances).
            torch.testing.assert_close(d_torch_hidden, d_kernel_hidden, atol=2e-2, rtol=4e-2)
            torch.testing.assert_close(d_torch_weight, d_kernel_weight, atol=2e-2, rtol=4e-2)
            torch.testing.assert_close(d_verl_hidden, d_kernel_hidden, atol=2e-2, rtol=4e-2)
            torch.testing.assert_close(d_verl_weight, d_kernel_weight, atol=2e-2, rtol=4e-2)
            torch.testing.assert_close(d_verl_fused_hidden, d_kernel_hidden, atol=2e-2, rtol=4e-2)
            torch.testing.assert_close(d_verl_fused_weight, d_kernel_weight, atol=2e-2, rtol=4e-2)

        # remove first latency (warm-up / compilation noise)
        torch_forward_latency = torch_forward_latency[1:]
        torch_backward_latency = torch_backward_latency[1:]
        verl_forward_latency = verl_forward_latency[1:]
        verl_backward_latency = verl_backward_latency[1:]
        verl_fused_forward_latency = verl_fused_forward_latency[1:]
        verl_fused_backward_latency = verl_fused_backward_latency[1:]
        kernel_forward_latency = kernel_forward_latency[1:]
        kernel_backward_latency = kernel_backward_latency[1:]

        print("\n[INFO]: Verified forward & backward correctness.")

        print(
            f"[INFO]: Forward pass: Torch implementation average time: "
            f"{sum(torch_forward_latency) / len(torch_forward_latency):.2f} ms"
        )
        print(
            f"[INFO]: Backward pass: torch implementation average time: "
            f"{sum(torch_backward_latency) / len(torch_backward_latency):.2f} ms"
        )
        print(
            f"[INFO]: Forward pass: VeRL implementation average time: "
            f"{sum(verl_forward_latency) / len(verl_forward_latency):.2f} ms"
        )
        print(
            f"[INFO]: Backward pass: VeRL implementation average time: "
            f"{sum(verl_backward_latency) / len(verl_backward_latency):.2f} ms"
        )
        print(
            f"[INFO]: Forward pass: VeRL Fused Entropy implementation average time: "
            f"{sum(verl_fused_forward_latency) / len(verl_fused_forward_latency):.2f} ms"
        )
        print(
            f"[INFO]: Backward pass: VeRL Fused Entropy implementation average time: "
            f"{sum(verl_fused_backward_latency) / len(verl_fused_backward_latency):.2f} ms"
        )
        print(
            f"[INFO]: Forward pass: Kernel implementation average time: "
            f"{sum(kernel_forward_latency) / len(kernel_forward_latency):.2f} ms"
        )
        print(
            f"[INFO]: Backward pass: kernel implementation average time: "
            f"{sum(kernel_backward_latency) / len(kernel_backward_latency):.2f} ms"
        )

    def check_storage(self, method_name, run_forward):
        """Report peak CUDA memory of one implementation's forward and backward pass.

        Args:
            method_name: human-readable label used in the printed report.
            run_forward: callable (hidden, weight, labels, temperature) -> (logprobs, entropy).
        """
        self.cleanup()
        self.generate_hyper()

        hidden, weight, labels = self.generate_forward_inputs()

        torch.cuda.reset_peak_memory_stats()
        (logprobs, entropy) = run_forward(hidden, weight, labels, self.temperature)
        torch.cuda.synchronize()
        torch_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024
        print(f"[INFO]: {method_name} Forward pass peak memory: {torch_max_memory:.2f} MB")

        g_entropy, g_logprobs = self.generate_backward_inputs()

        torch.cuda.reset_peak_memory_stats()
        (d_torch_hidden, d_torch_weight) = torch.autograd.grad(
            (entropy, logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False
        )
        torch.cuda.synchronize()
        torch_backward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024
        print(f"[INFO]: {method_name} Backward pass peak memory: {torch_backward_max_memory:.2f} MB")

    def check_storage_all(self):
        """Run the peak-memory report for every implementation."""
        self.check_storage("Torch", run_torch_entropy)
        self.check_storage("VeRL", run_verl_original_entropy)
        self.check_storage("VeRL Torch Fused", run_verl_torch_fused_entropy)
        self.check_storage("Kernel", linear_cross_entropy)
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
if __name__ == "__main__":
    # Uncomment to capture a CUDA memory timeline for the whole run.
    # torch.cuda.memory._record_memory_history()

    # Run every configured hyper-parameter preset end-to-end.
    for test_case_idx in range(MAX_TEST_CASES):
        print(f"[INFO] Running test case {test_case_idx}")
        test = TestLinearCrossEntropy(test_case_idx)

        test.verify_correctness()
        test.check_storage_all()

    # torch.cuda.memory._dump_snapshot("test_linear_cross_entropy.pkl")
|
code/RL_model/verl/verl_train/tests/utils/test_mlflow_key_sanitization.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import unittest
|
| 16 |
+
from unittest.mock import patch
|
| 17 |
+
|
| 18 |
+
from verl.utils.tracking import _MlflowLoggingAdapter
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class TestMlflowLoggingAdapter(unittest.TestCase):
    """Tests for metric-key sanitization performed by the mlflow logging adapter."""

    def test_sanitize_key_and_warning(self):
        """Keys with illegal characters or repeated slashes are sanitized, with a warning per key."""
        mlflow_adapter = _MlflowLoggingAdapter()
        raw_metrics = {
            "valid_key": 1.0,
            "invalid@key!": 2.0,
            "another/valid-key": 3.0,
            "bad key#": 4.0,
            "val-aux//reward/mean_at_1": 5.0,
            "val-core///acc/best_at_5": 6.0,
            "metric////with/many////slashes": 7.0,
        }
        # Capture what would be pushed to mlflow and what the adapter warns about.
        with (
            patch("mlflow.log_metrics") as mock_log_metrics,
            patch.object(mlflow_adapter, "logger") as mock_logger,
        ):
            mlflow_adapter.log(raw_metrics, step=5)

            sent_metrics = mock_log_metrics.call_args[1]["metrics"]
            # Illegal characters are rewritten: @ becomes _at_, ! and # become _.
            for sanitized_key in ("invalid_at_key_", "bad key_"):
                self.assertIn(sanitized_key, sent_metrics)
            for raw_key in ("invalid@key!", "bad key#"):
                self.assertNotIn(raw_key, sent_metrics)

            # Runs of slashes collapse down to a single slash.
            for collapsed_key in (
                "val-aux/reward/mean_at_1",
                "val-core/acc/best_at_5",
                "metric/with/many/slashes",
            ):
                self.assertIn(collapsed_key, sent_metrics)
            for raw_key in ("val-aux//reward/mean_at_1", "val-core///acc/best_at_5"):
                self.assertNotIn(raw_key, sent_metrics)

            # Every sanitized key triggered a warning naming the offender.
            warning_msgs = [str(call) for call in mock_logger.warning.call_args_list]
            self.assertTrue(any("invalid@key!" in msg and "invalid_at_key_" in msg for msg in warning_msgs))
            self.assertTrue(any("bad key#" in msg and "bad key_" in msg for msg in warning_msgs))
            for raw_key in (
                "val-aux//reward/mean_at_1",
                "val-core///acc/best_at_5",
                "metric////with/many////slashes",
            ):
                self.assertTrue(any(raw_key in msg for msg in warning_msgs))
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()
|
code/RL_model/verl/verl_train/tests/utils/test_model_on_cpu.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from types import SimpleNamespace # Or use a mock object library
|
| 16 |
+
|
| 17 |
+
import pytest
|
| 18 |
+
|
| 19 |
+
from verl.utils.model import update_model_config
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# Parametrize with different override scenarios
|
| 23 |
+
@pytest.mark.parametrize(
    "override_kwargs",
    [
        {"param_a": 5, "new_param": "plain_added"},
        {"param_a": 2, "nested_params": {"sub_param_x": "updated_x", "sub_param_z": True}},
    ],
)
def test_update_model_config(override_kwargs):
    """update_model_config must apply plain and nested overrides without touching other fields."""
    # Fresh mock config per parametrized case.
    cfg = SimpleNamespace(
        param_a=1, nested_params=SimpleNamespace(sub_param_x="original_x", sub_param_y=100), other_param="keep_me"
    )
    update_model_config(cfg, override_kwargs)

    nested = cfg.nested_params
    if "nested_params" in override_kwargs:
        # Nested override: targeted keys change, siblings survive, new keys appear.
        expected_nested = override_kwargs["nested_params"]
        assert nested.sub_param_x == expected_nested["sub_param_x"], "Nested sub_param_x mismatch"
        assert nested.sub_param_y == 100, "Nested sub_param_y should be unchanged"
        assert hasattr(nested, "sub_param_z"), "Expected nested sub_param_z to be added"
        assert nested.sub_param_z == expected_nested["sub_param_z"], "Value of sub_param_z mismatch"
    else:
        # Plain override: the nested namespace must be left completely alone.
        assert nested.sub_param_x == "original_x", "Nested sub_param_x should be unchanged"
        assert nested.sub_param_y == 100, "Nested sub_param_y should be unchanged"
        assert not hasattr(nested, "sub_param_z"), "Nested sub_param_z should not exist"
|
code/RL_model/verl/verl_train/tests/utils/test_nvtx_profile.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import unittest
|
| 17 |
+
from unittest.mock import MagicMock, patch
|
| 18 |
+
|
| 19 |
+
from verl.utils import omega_conf_to_dataclass
|
| 20 |
+
from verl.utils.profiler.config import NsightToolConfig, ProfilerConfig
|
| 21 |
+
from verl.utils.profiler.profile import DistProfiler
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class TestProfilerConfig(unittest.TestCase):
|
| 25 |
+
def test_config_init(self):
|
| 26 |
+
import os
|
| 27 |
+
|
| 28 |
+
from hydra import compose, initialize_config_dir
|
| 29 |
+
|
| 30 |
+
with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
|
| 31 |
+
cfg = compose(config_name="ppo_trainer")
|
| 32 |
+
for config in [
|
| 33 |
+
cfg.actor_rollout_ref.actor.profiler,
|
| 34 |
+
cfg.actor_rollout_ref.rollout.profiler,
|
| 35 |
+
cfg.actor_rollout_ref.ref.profiler,
|
| 36 |
+
cfg.critic.profiler,
|
| 37 |
+
cfg.reward_model.profiler,
|
| 38 |
+
]:
|
| 39 |
+
profiler_config = omega_conf_to_dataclass(config)
|
| 40 |
+
self.assertEqual(profiler_config.tool, config.tool)
|
| 41 |
+
self.assertEqual(profiler_config.enable, config.enable)
|
| 42 |
+
self.assertEqual(profiler_config.all_ranks, config.all_ranks)
|
| 43 |
+
self.assertEqual(profiler_config.ranks, config.ranks)
|
| 44 |
+
self.assertEqual(profiler_config.save_path, config.save_path)
|
| 45 |
+
self.assertEqual(profiler_config.ranks, config.ranks)
|
| 46 |
+
assert isinstance(profiler_config, ProfilerConfig)
|
| 47 |
+
with self.assertRaises(AttributeError):
|
| 48 |
+
_ = profiler_config.non_existing_key
|
| 49 |
+
assert config.get("non_existing_key") == profiler_config.get("non_existing_key")
|
| 50 |
+
assert config.get("non_existing_key", 1) == profiler_config.get("non_existing_key", 1)
|
| 51 |
+
|
| 52 |
+
def test_frozen_config(self):
|
| 53 |
+
"""Test that modifying frozen keys in ProfilerConfig raises exceptions."""
|
| 54 |
+
from dataclasses import FrozenInstanceError
|
| 55 |
+
|
| 56 |
+
from verl.utils.profiler.config import ProfilerConfig
|
| 57 |
+
|
| 58 |
+
# Create a new ProfilerConfig instance
|
| 59 |
+
config = ProfilerConfig(all_ranks=False, ranks=[0])
|
| 60 |
+
|
| 61 |
+
with self.assertRaises(FrozenInstanceError):
|
| 62 |
+
config.all_ranks = True
|
| 63 |
+
|
| 64 |
+
with self.assertRaises(FrozenInstanceError):
|
| 65 |
+
config.ranks = [1, 2, 3]
|
| 66 |
+
|
| 67 |
+
with self.assertRaises(TypeError):
|
| 68 |
+
config["all_ranks"] = True
|
| 69 |
+
|
| 70 |
+
with self.assertRaises(TypeError):
|
| 71 |
+
config["ranks"] = [1, 2, 3]
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class TestNsightSystemsProfiler(unittest.TestCase):
|
| 75 |
+
"""Test suite for NsightSystemsProfiler functionality.
|
| 76 |
+
|
| 77 |
+
Test Plan:
|
| 78 |
+
1. Initialization: Verify profiler state after creation
|
| 79 |
+
2. Basic Profiling: Test start/stop functionality
|
| 80 |
+
3. Discrete Mode: TODO: Test discrete profiling behavior
|
| 81 |
+
4. Annotation: Test the annotate decorator in both normal and discrete modes
|
| 82 |
+
5. Config Validation: Verify proper config initialization from OmegaConf
|
| 83 |
+
"""
|
| 84 |
+
|
| 85 |
+
def setUp(self):
|
| 86 |
+
self.config = ProfilerConfig(tool="nsys", enable=True, all_ranks=True)
|
| 87 |
+
self.rank = 0
|
| 88 |
+
self.profiler = DistProfiler(self.rank, self.config, tool_config=NsightToolConfig(discrete=False))
|
| 89 |
+
|
| 90 |
+
def test_initialization(self):
|
| 91 |
+
self.assertEqual(self.profiler.check_this_rank(), True)
|
| 92 |
+
self.assertEqual(self.profiler.check_this_step(), False)
|
| 93 |
+
|
| 94 |
+
def test_start_stop_profiling(self):
|
| 95 |
+
with patch("torch.cuda.profiler.start") as mock_start, patch("torch.cuda.profiler.stop") as mock_stop:
|
| 96 |
+
# Test start
|
| 97 |
+
self.profiler.start()
|
| 98 |
+
self.assertTrue(self.profiler.check_this_step())
|
| 99 |
+
mock_start.assert_called_once()
|
| 100 |
+
|
| 101 |
+
# Test stop
|
| 102 |
+
self.profiler.stop()
|
| 103 |
+
self.assertFalse(self.profiler.check_this_step())
|
| 104 |
+
mock_stop.assert_called_once()
|
| 105 |
+
|
| 106 |
+
# def test_discrete_profiling(self):
|
| 107 |
+
# discrete_config = ProfilerConfig(discrete=True, all_ranks=True)
|
| 108 |
+
# profiler = NsightSystemsProfiler(self.rank, discrete_config)
|
| 109 |
+
|
| 110 |
+
# with patch("torch.cuda.profiler.start") as mock_start, patch("torch.cuda.profiler.stop") as mock_stop:
|
| 111 |
+
# profiler.start()
|
| 112 |
+
# self.assertTrue(profiler.this_step)
|
| 113 |
+
# mock_start.assert_not_called() # Shouldn't start immediately in discrete mode
|
| 114 |
+
|
| 115 |
+
# profiler.stop()
|
| 116 |
+
# self.assertFalse(profiler.this_step)
|
| 117 |
+
# mock_stop.assert_not_called() # Shouldn't stop immediately in discrete mode
|
| 118 |
+
|
| 119 |
+
def test_annotate_decorator(self):
|
| 120 |
+
mock_self = MagicMock()
|
| 121 |
+
mock_self.profiler = self.profiler
|
| 122 |
+
mock_self.profiler.start()
|
| 123 |
+
decorator = mock_self.profiler.annotate(message="test")
|
| 124 |
+
|
| 125 |
+
@decorator
|
| 126 |
+
def test_func(self, *args, **kwargs):
|
| 127 |
+
return "result"
|
| 128 |
+
|
| 129 |
+
with (
|
| 130 |
+
patch("torch.cuda.profiler.start") as mock_start,
|
| 131 |
+
patch("torch.cuda.profiler.stop") as mock_stop,
|
| 132 |
+
patch("verl.utils.profiler.nvtx_profile.mark_start_range") as mock_start_range,
|
| 133 |
+
patch("verl.utils.profiler.nvtx_profile.mark_end_range") as mock_end_range,
|
| 134 |
+
):
|
| 135 |
+
result = test_func(mock_self)
|
| 136 |
+
self.assertEqual(result, "result")
|
| 137 |
+
mock_start_range.assert_called_once()
|
| 138 |
+
mock_end_range.assert_called_once()
|
| 139 |
+
mock_start.assert_not_called() # Not discrete mode
|
| 140 |
+
mock_stop.assert_not_called() # Not discrete mode
|
| 141 |
+
|
| 142 |
+
# def test_annotate_discrete_mode(self):
|
| 143 |
+
# discrete_config = ProfilerConfig(discrete=True, all_ranks=True)
|
| 144 |
+
# profiler = NsightSystemsProfiler(self.rank, discrete_config)
|
| 145 |
+
# mock_self = MagicMock()
|
| 146 |
+
# mock_self.profiler = profiler
|
| 147 |
+
# mock_self.profiler.this_step = True
|
| 148 |
+
|
| 149 |
+
# @NsightSystemsProfiler.annotate(message="test")
|
| 150 |
+
# def test_func(self, *args, **kwargs):
|
| 151 |
+
# return "result"
|
| 152 |
+
|
| 153 |
+
# with (
|
| 154 |
+
# patch("torch.cuda.profiler.start") as mock_start,
|
| 155 |
+
# patch("torch.cuda.profiler.stop") as mock_stop,
|
| 156 |
+
# patch("verl.utils.profiler.nvtx_profile.mark_start_range") as mock_start_range,
|
| 157 |
+
# patch("verl.utils.profiler.nvtx_profile.mark_end_range") as mock_end_range,
|
| 158 |
+
# ):
|
| 159 |
+
# result = test_func(mock_self)
|
| 160 |
+
# self.assertEqual(result, "result")
|
| 161 |
+
# mock_start_range.assert_called_once()
|
| 162 |
+
# mock_end_range.assert_called_once()
|
| 163 |
+
# mock_start.assert_called_once() # Should start in discrete mode
|
| 164 |
+
# mock_stop.assert_called_once() # Should stop in discrete mode
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
if __name__ == "__main__":
|
| 168 |
+
unittest.main()
|
code/RL_model/verl/verl_train/tests/utils/test_rollout_skip_on_cpu.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import shutil
|
| 15 |
+
import tempfile
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from unittest.mock import MagicMock
|
| 18 |
+
|
| 19 |
+
import pytest
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
+
from verl.utils.rollout_skip import DataProto, RolloutSkip
|
| 23 |
+
|
| 24 |
+
len_prompt = 50
|
| 25 |
+
len_response = 100
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def temp_dir():
|
| 29 |
+
# Create a temporary directory
|
| 30 |
+
temp_dir = Path(tempfile.mkdtemp())
|
| 31 |
+
yield temp_dir
|
| 32 |
+
# Cleanup
|
| 33 |
+
shutil.rmtree(temp_dir)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def build_generate_fn(gen_bs, n):
|
| 37 |
+
len_tokenizer = 1024
|
| 38 |
+
|
| 39 |
+
def iterate():
|
| 40 |
+
while True:
|
| 41 |
+
prompt = torch.randint(len_tokenizer, size=(gen_bs, len_prompt)).repeat_interleave(n, dim=0)
|
| 42 |
+
generate = torch.randint(len_tokenizer, size=(gen_bs * n, len_response))
|
| 43 |
+
data = DataProto.from_dict(tensors={"prompt": prompt, "response": generate})
|
| 44 |
+
yield data
|
| 45 |
+
|
| 46 |
+
mock_infer_engine = iterate()
|
| 47 |
+
|
| 48 |
+
def fn(batch, **kwargs):
|
| 49 |
+
# Simulate the inference engine returning the next batch
|
| 50 |
+
return next(mock_infer_engine)
|
| 51 |
+
|
| 52 |
+
return fn
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@pytest.fixture(params=[(32, 4), (64, 4), (64, 8)])
|
| 56 |
+
def mock_rollout_wg(request):
|
| 57 |
+
gen_bs, n = request.param
|
| 58 |
+
rollout_wg = MagicMock()
|
| 59 |
+
|
| 60 |
+
config = MagicMock()
|
| 61 |
+
config.actor_rollout_ref.rollout = {
|
| 62 |
+
"n": n,
|
| 63 |
+
"skip_dump_dir": next(temp_dir()),
|
| 64 |
+
}
|
| 65 |
+
config.data = {"gen_batch_size": gen_bs}
|
| 66 |
+
|
| 67 |
+
rollout_wg.generate_sequences = build_generate_fn(gen_bs, n)
|
| 68 |
+
|
| 69 |
+
yield config, rollout_wg
|
| 70 |
+
# Cleanup
|
| 71 |
+
shutil.rmtree(next(temp_dir()))
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class TestRolloutSkip:
|
| 75 |
+
def test_initialization(self, capsys):
|
| 76 |
+
"""Test that RolloutSkip initializes correctly"""
|
| 77 |
+
config = MagicMock()
|
| 78 |
+
config.actor_rollout_ref.rollout = {
|
| 79 |
+
"n": 16,
|
| 80 |
+
"skip_dump_dir": "tmp/rollout_dump",
|
| 81 |
+
}
|
| 82 |
+
config.data = {"gen_batch_size": 128}
|
| 83 |
+
mock_rollout_wg = MagicMock()
|
| 84 |
+
skip = RolloutSkip(config, mock_rollout_wg)
|
| 85 |
+
|
| 86 |
+
assert skip.n == 16
|
| 87 |
+
assert skip.gbs == 128
|
| 88 |
+
assert str(skip.dumped_dir) == "tmp/rollout_dump"
|
| 89 |
+
|
| 90 |
+
assert skip._rollout_wg == mock_rollout_wg
|
| 91 |
+
skip.wrap_generate_sequences()
|
| 92 |
+
captured = capsys.readouterr()
|
| 93 |
+
assert "Successfully patched" in captured.out
|
| 94 |
+
|
| 95 |
+
def test_generate_without_wrap(self, mock_rollout_wg):
|
| 96 |
+
"""Test that generate_sequences works without wrapping"""
|
| 97 |
+
|
| 98 |
+
config, rollout_wg = mock_rollout_wg
|
| 99 |
+
_ = RolloutSkip(config, rollout_wg)
|
| 100 |
+
|
| 101 |
+
_result = rollout_wg.generate_sequences(MagicMock())
|
| 102 |
+
for _ in range(10):
|
| 103 |
+
result = rollout_wg.generate_sequences(MagicMock())
|
| 104 |
+
assert isinstance(result, DataProto)
|
| 105 |
+
# * make sure the data is different
|
| 106 |
+
assert torch.abs(_result.batch["prompt"] - result.batch["prompt"]).sum() > 0
|
| 107 |
+
assert torch.abs(_result.batch["response"] - result.batch["response"]).sum() > 0
|
| 108 |
+
_result = result
|
| 109 |
+
|
| 110 |
+
def test_dump(self, mock_rollout_wg, capsys):
|
| 111 |
+
config, rollout_wg = mock_rollout_wg
|
| 112 |
+
skip = RolloutSkip(config, rollout_wg)
|
| 113 |
+
skip.wrap_generate_sequences()
|
| 114 |
+
|
| 115 |
+
result = rollout_wg.generate_sequences(MagicMock())
|
| 116 |
+
# * check if dump is OK
|
| 117 |
+
assert skip.curr_path_dump.exists()
|
| 118 |
+
captured = capsys.readouterr()
|
| 119 |
+
assert "Successfully dump data in" in captured.out
|
| 120 |
+
# * get file size, estimate file size
|
| 121 |
+
file_size = skip.curr_path_dump.stat().st_size
|
| 122 |
+
est_file_size = (len_prompt + len_response) * skip.gbs * skip.n * result.batch["prompt"].dtype.itemsize
|
| 123 |
+
assert file_size >= est_file_size, "Dumped file size is smaller than expected"
|
| 124 |
+
|
| 125 |
+
def test_generate_with_wrap(self, mock_rollout_wg, capsys):
|
| 126 |
+
"""Test that generate_sequences works without wrapping"""
|
| 127 |
+
|
| 128 |
+
config, rollout_wg = mock_rollout_wg
|
| 129 |
+
skip = RolloutSkip(config, rollout_wg)
|
| 130 |
+
skip.wrap_generate_sequences()
|
| 131 |
+
|
| 132 |
+
_result = rollout_wg.generate_sequences(MagicMock())
|
| 133 |
+
|
| 134 |
+
for _ in range(10):
|
| 135 |
+
result = rollout_wg.generate_sequences(MagicMock())
|
| 136 |
+
assert isinstance(result, DataProto)
|
| 137 |
+
# * make sure the data is different
|
| 138 |
+
assert torch.abs(_result.batch["prompt"] - result.batch["prompt"]).sum() == 0
|
| 139 |
+
assert torch.abs(_result.batch["response"] - result.batch["response"]).sum() == 0
|
| 140 |
+
captured = capsys.readouterr()
|
| 141 |
+
assert "Successfully load pre-generated data from" in captured.out
|
| 142 |
+
_result = result
|
code/RL_model/verl/verl_train/tests/utils/test_rollout_trace_on_cpu.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import sys
|
| 17 |
+
from unittest.mock import MagicMock, patch
|
| 18 |
+
|
| 19 |
+
import pytest
|
| 20 |
+
|
| 21 |
+
from verl.utils.rollout_trace import RolloutTraceConfig, rollout_trace_attr, rollout_trace_op
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@pytest.fixture(autouse=True)
|
| 25 |
+
def reset_rollout_trace_config_singleton():
|
| 26 |
+
"""Fixture to reset the RolloutTraceConfig singleton before each test."""
|
| 27 |
+
RolloutTraceConfig.reset()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@pytest.fixture
|
| 31 |
+
def mock_weave_client():
|
| 32 |
+
"""Mocks the weave module and its client, yielding the mock client."""
|
| 33 |
+
mock_weave = MagicMock()
|
| 34 |
+
mock_client = MagicMock()
|
| 35 |
+
mock_call = MagicMock()
|
| 36 |
+
mock_client.create_call.return_value = mock_call
|
| 37 |
+
mock_weave.init.return_value = mock_client
|
| 38 |
+
|
| 39 |
+
# Also mock the call_context if it's used internally by the decorator
|
| 40 |
+
mock_weave.trace.context.call_context.return_value = MagicMock()
|
| 41 |
+
|
| 42 |
+
with patch.dict(sys.modules, {"weave": mock_weave, "weave.trace.context": mock_weave.trace.context}):
|
| 43 |
+
yield mock_client
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class TracedClass:
|
| 47 |
+
@rollout_trace_op
|
| 48 |
+
# @weave.op
|
| 49 |
+
# @mlflow.trace
|
| 50 |
+
async def my_method(self, a, b="default"):
|
| 51 |
+
return f"result: {a}, {b}"
|
| 52 |
+
|
| 53 |
+
@rollout_trace_op
|
| 54 |
+
# @weave.op
|
| 55 |
+
# @mlflow.trace
|
| 56 |
+
async def middle_method(self, a, b="default"):
|
| 57 |
+
await self.my_method("test_a1", b="test_b1")
|
| 58 |
+
return f"result: {a}, {b}"
|
| 59 |
+
|
| 60 |
+
@rollout_trace_op
|
| 61 |
+
# @mlflow.trace
|
| 62 |
+
async def my_method_with_exception(self):
|
| 63 |
+
raise ValueError("Test Exception")
|
| 64 |
+
|
| 65 |
+
async def upper_method(self):
|
| 66 |
+
await self.my_method("test_a0", b="test_b0")
|
| 67 |
+
await self.middle_method("test_a2", b="test_b2")
|
| 68 |
+
return True
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class UntracedClass:
|
| 72 |
+
@rollout_trace_op
|
| 73 |
+
async def my_method(self, x):
|
| 74 |
+
return x * 2
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
async def test_rollout_trace_on_untraced_class():
|
| 78 |
+
"""Tests that the decorator works correctly when no backend is configured."""
|
| 79 |
+
instance = UntracedClass()
|
| 80 |
+
assert await instance.my_method(10) == 20
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
async def test_rollout_trace_with_tracer(mock_weave_client):
|
| 84 |
+
"""Tests that the decorator calls the tracer's methods correctly."""
|
| 85 |
+
RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="weave")
|
| 86 |
+
instance = TracedClass()
|
| 87 |
+
assert RolloutTraceConfig.get_client() is mock_weave_client
|
| 88 |
+
|
| 89 |
+
result = await instance.my_method("test_a", b="test_b")
|
| 90 |
+
|
| 91 |
+
assert result == "result: test_a, test_b"
|
| 92 |
+
mock_weave_client.create_call.assert_called_once()
|
| 93 |
+
call_kwargs = mock_weave_client.create_call.call_args.kwargs
|
| 94 |
+
assert call_kwargs["op"] == "TracedClass.my_method"
|
| 95 |
+
expected_inputs = {"a": "test_a", "b": "test_b"}
|
| 96 |
+
assert call_kwargs["inputs"] == expected_inputs
|
| 97 |
+
|
| 98 |
+
mock_call = mock_weave_client.create_call.return_value
|
| 99 |
+
mock_weave_client.finish_call.assert_called_once_with(mock_call, output=result)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
async def test_rollout_trace_with_exception(mock_weave_client):
|
| 103 |
+
"""Tests that `finish` is called with the exception when one is raised."""
|
| 104 |
+
RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="weave")
|
| 105 |
+
instance = TracedClass()
|
| 106 |
+
|
| 107 |
+
with pytest.raises(ValueError, match="Test Exception"):
|
| 108 |
+
await instance.my_method_with_exception()
|
| 109 |
+
|
| 110 |
+
mock_weave_client.create_call.assert_called_once()
|
| 111 |
+
mock_call = mock_weave_client.create_call.return_value
|
| 112 |
+
mock_weave_client.finish_call.assert_called_once()
|
| 113 |
+
|
| 114 |
+
# Check that finish_call was called with the exception
|
| 115 |
+
args, kwargs = mock_weave_client.finish_call.call_args
|
| 116 |
+
assert args[0] == mock_call
|
| 117 |
+
assert "exception" in kwargs
|
| 118 |
+
assert isinstance(kwargs["exception"], ValueError)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
async def test_rollout_trace_with_dummy_backend(mock_weave_client):
|
| 122 |
+
"""Tests that the tracer is not called when the backend is 'dummy'."""
|
| 123 |
+
RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="dummy")
|
| 124 |
+
instance = TracedClass()
|
| 125 |
+
|
| 126 |
+
await instance.my_method("test_a")
|
| 127 |
+
|
| 128 |
+
mock_weave_client.create_call.assert_not_called()
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
async def test_trace_disabled_with_trace_false(mock_weave_client):
|
| 132 |
+
"""Tests that tracing is disabled when trace=False."""
|
| 133 |
+
RolloutTraceConfig.init(
|
| 134 |
+
project_name="my-project",
|
| 135 |
+
experiment_name="my-experiment",
|
| 136 |
+
backend="weave",
|
| 137 |
+
)
|
| 138 |
+
instance = TracedClass()
|
| 139 |
+
|
| 140 |
+
assert RolloutTraceConfig.get_backend() == "weave"
|
| 141 |
+
|
| 142 |
+
with rollout_trace_attr(step=1, sample_index=0, rollout_n=0, trace=False):
|
| 143 |
+
result = await instance.my_method("test_a", b="test_b")
|
| 144 |
+
assert result == "result: test_a, test_b"
|
| 145 |
+
|
| 146 |
+
# No tracing should have occurred
|
| 147 |
+
mock_weave_client.create_call.assert_not_called()
|
| 148 |
+
|
| 149 |
+
# Verify that tracing works again with trace=True (default)
|
| 150 |
+
with rollout_trace_attr(step=1, sample_index=0, rollout_n=0):
|
| 151 |
+
result = await instance.my_method("test_a", b="test_b")
|
| 152 |
+
assert result == "result: test_a, test_b"
|
| 153 |
+
|
| 154 |
+
assert mock_weave_client.create_call.call_count == 1
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
async def test_trace_false_disables_nested_trace_ops(mock_weave_client):
|
| 158 |
+
"""Tests that trace=False disables all nested @rollout_trace_op calls."""
|
| 159 |
+
RolloutTraceConfig.init(
|
| 160 |
+
project_name="my-project",
|
| 161 |
+
experiment_name="my-experiment",
|
| 162 |
+
backend="weave",
|
| 163 |
+
)
|
| 164 |
+
instance = TracedClass()
|
| 165 |
+
|
| 166 |
+
with rollout_trace_attr(step=1, sample_index=0, rollout_n=0, trace=False):
|
| 167 |
+
# Call upper_method which internally calls my_method and middle_method
|
| 168 |
+
# All of these are decorated with @rollout_trace_op
|
| 169 |
+
result = await instance.upper_method()
|
| 170 |
+
assert result is True
|
| 171 |
+
|
| 172 |
+
# No tracing should have occurred for any of the nested calls
|
| 173 |
+
mock_weave_client.create_call.assert_not_called()
|
| 174 |
+
|
| 175 |
+
with rollout_trace_attr(step=1, sample_index=0, rollout_n=0):
|
| 176 |
+
result = await instance.my_method("test_a", b="test_b")
|
| 177 |
+
assert result == "result: test_a, test_b"
|
| 178 |
+
|
| 179 |
+
assert mock_weave_client.create_call.call_count == 1
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
async def test_trace_enabled_restored_after_exception(mock_weave_client):
|
| 183 |
+
"""Tests that trace state is restored even if an exception occurs when trace=False."""
|
| 184 |
+
RolloutTraceConfig.init(
|
| 185 |
+
project_name="my-project",
|
| 186 |
+
experiment_name="my-experiment",
|
| 187 |
+
backend="weave",
|
| 188 |
+
)
|
| 189 |
+
instance = TracedClass()
|
| 190 |
+
|
| 191 |
+
assert RolloutTraceConfig.get_backend() == "weave"
|
| 192 |
+
|
| 193 |
+
# Use trace=False and raise an exception
|
| 194 |
+
try:
|
| 195 |
+
with rollout_trace_attr(step=1, sample_index=0, rollout_n=0, trace=False):
|
| 196 |
+
raise RuntimeError("Test exception with trace disabled")
|
| 197 |
+
except RuntimeError:
|
| 198 |
+
pass
|
| 199 |
+
|
| 200 |
+
with rollout_trace_attr(step=1, sample_index=0, rollout_n=0):
|
| 201 |
+
result = await instance.my_method("test_a", b="test_b")
|
| 202 |
+
assert result == "result: test_a, test_b"
|
| 203 |
+
|
| 204 |
+
assert mock_weave_client.create_call.call_count == 1
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
@pytest.mark.skipif(
|
| 208 |
+
os.environ.get("RUN_WEAVE_INTEGRATION_TESTS", "false").lower() != "true",
|
| 209 |
+
reason="Skipping weave integration test. Set RUN_WEAVE_INTEGRATION_TESTS=true to run.",
|
| 210 |
+
)
|
| 211 |
+
async def test_rollout_trace_with_real_weave_backend():
|
| 212 |
+
"""Integration test with a real weave backend."""
|
| 213 |
+
|
| 214 |
+
# This assumes that the weave environment (e.g., project) is configured
|
| 215 |
+
RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="weave")
|
| 216 |
+
|
| 217 |
+
instance = TracedClass()
|
| 218 |
+
|
| 219 |
+
with rollout_trace_attr(step=1, sample_index=2, rollout_n=3):
|
| 220 |
+
await instance.upper_method()
|
| 221 |
+
|
| 222 |
+
with pytest.raises(ValueError, match="Test Exception"):
|
| 223 |
+
await instance.my_method_with_exception()
|
| 224 |
+
|
| 225 |
+
print("\nWeave integration test ran successfully. Check your weave project for the trace.")
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
@pytest.mark.skipif(
|
| 229 |
+
os.environ.get("RUN_MLFLOW_INTEGRATION_TESTS", "false").lower() != "true",
|
| 230 |
+
reason="Skipping mlflow integration test. Set RUN_MLFLOW_INTEGRATION_TESTS=true to run.",
|
| 231 |
+
)
|
| 232 |
+
async def test_rollout_trace_with_real_mlflow_backend():
|
| 233 |
+
"""Integration test with a real mlflow backend."""
|
| 234 |
+
|
| 235 |
+
# This assumes that the mlflow environment (e.g., project) is configured
|
| 236 |
+
RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="mlflow")
|
| 237 |
+
|
| 238 |
+
instance = TracedClass()
|
| 239 |
+
|
| 240 |
+
with rollout_trace_attr(step=1, sample_index=2, rollout_n=3, name="agent_run"):
|
| 241 |
+
assert await instance.upper_method()
|
| 242 |
+
|
| 243 |
+
# with pytest.raises(ValueError, match="Test Exception"):
|
| 244 |
+
# await instance.my_method_with_exception()
|
| 245 |
+
|
| 246 |
+
print("\nWeave integration test ran successfully. Check your weave project for the trace.")
|
code/RL_model/verl/verl_train/tests/utils/test_seqlen_balancing.py
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.distributed as dist
|
| 17 |
+
import torch.multiprocessing as mp
|
| 18 |
+
|
| 19 |
+
from verl import DataProto
|
| 20 |
+
from verl.utils.device import get_device_name, get_nccl_backend, get_torch_device
|
| 21 |
+
from verl.utils.model import create_random_mask
|
| 22 |
+
from verl.utils.seqlen_balancing import (
|
| 23 |
+
ceildiv,
|
| 24 |
+
get_reverse_idx,
|
| 25 |
+
prepare_dynamic_batch,
|
| 26 |
+
rearrange_micro_batches,
|
| 27 |
+
restore_dynamic_batch,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def test_seqlen_balancing():
|
| 32 |
+
input_ids = torch.randint(low=0, high=10, size=(20, 100))
|
| 33 |
+
|
| 34 |
+
attention_mask = create_random_mask(
|
| 35 |
+
input_ids=input_ids, max_ratio_of_left_padding=0.1, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.5
|
| 36 |
+
)
|
| 37 |
+
data = {"input_ids": input_ids, "attention_mask": attention_mask}
|
| 38 |
+
dataproto = DataProto.from_single_dict(data)
|
| 39 |
+
micro_batches, micro_bsz_idx_lst = rearrange_micro_batches(dataproto.batch, max_token_len=300)
|
| 40 |
+
batch = torch.cat(micro_batches)
|
| 41 |
+
micro_bsz_idx = []
|
| 42 |
+
for idx in micro_bsz_idx_lst:
|
| 43 |
+
micro_bsz_idx.extend(idx)
|
| 44 |
+
reverse_idx_map = get_reverse_idx(micro_bsz_idx)
|
| 45 |
+
reverse_idx_map = torch.tensor(reverse_idx_map)
|
| 46 |
+
new_batch = batch[reverse_idx_map]
|
| 47 |
+
torch.testing.assert_close(new_batch, dataproto.batch)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def test_dynamic_batch():
|
| 51 |
+
input_ids = torch.randint(low=0, high=10, size=(20, 100))
|
| 52 |
+
|
| 53 |
+
attention_mask = create_random_mask(
|
| 54 |
+
input_ids=input_ids, max_ratio_of_left_padding=0.1, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.5
|
| 55 |
+
)
|
| 56 |
+
data = {"input_ids": input_ids, "attention_mask": attention_mask}
|
| 57 |
+
dataproto = DataProto.from_single_dict(data)
|
| 58 |
+
micro_batches, micro_bsz_idx_lst = prepare_dynamic_batch(dataproto, max_token_len=300)
|
| 59 |
+
input_ids = torch.cat([micro_batch.batch["input_ids"] for micro_batch in micro_batches], dim=0)
|
| 60 |
+
input_ids = restore_dynamic_batch(input_ids, micro_bsz_idx_lst)
|
| 61 |
+
torch.testing.assert_close(input_ids, dataproto.batch["input_ids"])
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _worker(rank, world_size, init_method, max_token_len, use_same_dp, min_mb):
    """Spawned subprocess: exercises rearrange_micro_batches' distributed options.

    Verifies the number of micro batches enforced by `min_num_micro_batch` or
    `same_micro_num_in_dp`, then checks the batch can be reconstructed from
    the micro batches.
    """
    # 1) init process group & CUDA
    get_torch_device().set_device(rank)
    dist.init_process_group(
        backend=get_nccl_backend(),
        init_method=init_method,
        world_size=world_size,
        rank=rank,
    )

    # 2) build a small random batch (each rank different length to force mismatch)
    torch.manual_seed(42 + rank)
    input_ids = torch.randint(0, 10, (20 + rank * 5, 100), device=f"{get_device_name()}:{rank}")
    attention_mask = create_random_mask(
        input_ids=input_ids,
        max_ratio_of_left_padding=0.1,
        max_ratio_of_valid_token=0.9,
        min_ratio_of_valid_token=0.5,
    )
    dp = {"input_ids": input_ids, "attention_mask": attention_mask}
    proto = DataProto.from_single_dict(dp)
    batch = proto.batch

    # 3) call rearrange_micro_batches with one of the two params under test
    micros, idx_lst = rearrange_micro_batches(
        batch,
        max_token_len=max_token_len,
        dp_group=dist.group.WORLD,
        same_micro_num_in_dp=use_same_dp,
        min_num_micro_batch=min_mb,
    )

    # 4) check the enforced counts
    seq_len_effective: torch.Tensor = batch["attention_mask"].sum(dim=1)
    total_seqlen = seq_len_effective.sum().item()
    local = min(len(seq_len_effective), ceildiv(total_seqlen, max_token_len))

    # BUGFIX: these three cases are mutually exclusive. Previously the final
    # `else` was attached only to the `use_same_dp` check, so with min_mb set
    # (and use_same_dp False) the natural-count assertion also ran and could
    # contradict the `max(local, min_mb)` assertion whenever min_mb > local.
    if min_mb is not None:
        expected = max(local, min_mb)
        assert len(micros) == expected
    elif use_same_dp:
        # gather all local counts and expect the global maximum everywhere
        counts = [torch.zeros(1, device=f"{get_device_name()}:{rank}") for _ in range(world_size)]
        counts[rank].fill_(local)
        dist.all_gather(counts, counts[rank])
        expected = max(int(c.item()) for c in counts)
        assert len(micros) == expected
    else:
        # if neither, we get the local natural count
        assert len(micros) == local

    # 5) reconstruction sanity: concat -> reverse_idx -> orig
    flat = torch.cat(micros, dim=0)
    idx = []
    for sub in idx_lst:
        idx.extend(sub)
    inv = get_reverse_idx(idx)
    inv = torch.tensor(inv, device=flat.device)
    reconstructed = flat[inv]
    torch.testing.assert_close(reconstructed, batch)

    dist.destroy_process_group()
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def test_dataproto_split_uneven():
    """Test DataProto.split with uneven splits"""
    import numpy as np

    ids = torch.randint(low=0, high=10, size=(10, 5))
    mask = torch.ones(10, 5)
    proto = DataProto.from_single_dict({"input_ids": ids, "attention_mask": mask})

    # Split size 3 over 10 items -> chunk lengths [3, 3, 3, 1]
    chunks = proto.split(3)
    assert [len(c) for c in chunks] == [3, 3, 3, 1]

    rebuilt = DataProto.concat(chunks)
    for key in ("input_ids", "attention_mask"):
        torch.testing.assert_close(rebuilt.batch[key], proto.batch[key])

    # Split size equal to the dataset length -> a single chunk
    chunks = proto.split(10)
    assert len(chunks) == 1
    assert len(chunks[0]) == 10

    # Split size larger than the dataset length -> a single chunk with all data
    chunks = proto.split(15)
    assert len(chunks) == 1
    assert len(chunks[0]) == 10

    # Same behaviour with a non-tensor field attached
    with_labels = DataProto.from_single_dict(
        {
            "input_ids": ids,
            "attention_mask": mask,
            "labels": np.array([f"label_{i}" for i in range(10)], dtype=object),
        }
    )
    chunks = with_labels.split(3)
    assert [len(c) for c in chunks] == [3, 3, 3, 1]

    # Non-tensor data must survive the split/concat round trip
    rebuilt = DataProto.concat(chunks)
    np.testing.assert_array_equal(rebuilt.non_tensor_batch["labels"], with_labels.non_tensor_batch["labels"])
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def test_seqlen_balancing_distributed_params(tmp_path):
    """Spawn two-rank jobs to exercise min_num_micro_batch and same_micro_num_in_dp."""
    world_size = 2

    def _fresh_init_method(tag):
        # BUGFIX: each process group needs its own rendezvous file. Reusing a
        # single file:// store across two init_process_group calls leaves the
        # first job's keys in the store, which can make the second rendezvous
        # fail or hang.
        init_file = tmp_path / f"dist_init_{tag}"
        init_file.write_text("")  # empty file
        return f"file://{init_file}"

    # test min_num_micro_batch only
    mp.spawn(
        _worker,
        args=(world_size, _fresh_init_method("min_mb"), 300, False, 4),
        nprocs=world_size,
        join=True,
    )

    # test same_micro_num_in_dp only
    mp.spawn(
        _worker,
        args=(world_size, _fresh_init_method("same_dp"), 300, True, None),
        nprocs=world_size,
        join=True,
    )
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def test_group_balanced_partitions():
    """Test group-level balancing keeps same-uid samples together."""
    from verl.utils.seqlen_balancing import get_group_balanced_partitions

    # Four groups of four samples each, with per-group sequence lengths
    # 100 / 200 / 150 / 50 respectively.
    seqlens = [100] * 4 + [200] * 4 + [150] * 4 + [50] * 4
    uids = [0] * 4 + [1] * 4 + [2] * 4 + [3] * 4

    parts = get_group_balanced_partitions(seqlens, uids, k_partitions=2)
    assert len(parts) == 2

    # Every index appears in some partition
    covered = set()
    for part in parts:
        covered.update(part)
    assert covered == set(range(16))

    # Samples sharing a uid are never split across partitions
    for part in parts:
        members = set(part)
        for uid in {uids[i] for i in part}:
            owners = [i for i, u in enumerate(uids) if u == uid]
            assert members.issuperset(owners), f"uid {uid} samples split across partitions"
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def test_group_balanced_partitions_single_sample_groups():
    """Test group balancing with single-sample groups (n=1)."""
    from verl.utils.seqlen_balancing import get_group_balanced_partitions

    # Every sample carries a distinct uid, i.e. each group has size one.
    seqlens = [100, 200, 150, 50, 300, 250]
    uids = list(range(6))

    parts = get_group_balanced_partitions(seqlens, uids, k_partitions=2)

    assert len(parts) == 2
    seen = set()
    for part in parts:
        seen.update(part)
    assert seen == set(range(6))
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def test_group_balanced_partitions_equal_size():
    """Test group balancing with equal_size constraint simulation."""
    from verl.utils.seqlen_balancing import get_group_balanced_partitions

    # Eight two-sample groups, partitioned into 4 (simulating world_size=4).
    base_lengths = [100, 200, 150, 50, 300, 250, 180, 120]
    seqlens = [length for length in base_lengths for _ in range(2)]
    uids = [gid for gid in range(8) for _ in range(2)]

    parts = get_group_balanced_partitions(seqlens, uids, k_partitions=4)
    assert len(parts) == 4

    # Every index is covered
    covered = set()
    for part in parts:
        covered |= set(part)
    assert covered == set(range(16))

    # Same-uid samples stay together
    for part in parts:
        for uid in {uids[i] for i in part}:
            group = [i for i, u in enumerate(uids) if u == uid]
            assert all(i in part for i in group)
|
code/RL_model/verl/verl_train/tests/utils/test_shared_memory.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import multiprocessing
|
| 16 |
+
import unittest
|
| 17 |
+
from multiprocessing import shared_memory
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
|
| 21 |
+
from verl.workers.rollout.vllm_rollout.utils import create_shared_memory, rebuild_shared_memory
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class TestSharedMemory(unittest.TestCase):
    """Test cases for shared memory utility functions."""

    def setUp(self):
        """Set up test fixtures before each test method."""
        # Use short unique names to avoid POSIX shared memory name length limits
        import uuid

        short_id = uuid.uuid4().hex[:8]
        self.test_name = f"shm_{short_id}"

    def tearDown(self):
        """Clean up shared memory after each test method."""
        # BUGFIX: POSIX shared memory segments persist until explicitly
        # unlinked, so relying on the OS here leaked one segment per test.
        # Attach to the segment (if this test created one) and unlink it.
        try:
            shm = shared_memory.SharedMemory(name=self.test_name)
        except FileNotFoundError:
            return  # the test never created a segment under this name
        shm.close()
        try:
            shm.unlink()
        except FileNotFoundError:
            pass  # already removed by a concurrent cleanup

    def test_create_shared_memory_new(self):
        """Test creating new shared memory with unique name."""
        size = 1024

        shm = create_shared_memory(size, self.test_name)

        # Verify shared memory object is created correctly
        self.assertIsNotNone(shm)
        # Note: shared memory may have system-dependent size rounding
        self.assertGreaterEqual(shm.size, size)
        self.assertEqual(shm.name, self.test_name)

        # Clean up - delete tensor references first
        del shm

    def test_create_shared_memory_attach_existing(self):
        """Test that create_shared_memory attaches to existing shared memory when FileExistsError occurs."""
        size = 2048

        # First, create shared memory
        shm1 = create_shared_memory(size, self.test_name)
        self.assertGreaterEqual(shm1.size, size)

        # Second call should attach to existing memory
        shm2 = create_shared_memory(size, self.test_name)

        # Verify we attached to the same shared memory
        self.assertIsNotNone(shm2)
        self.assertGreaterEqual(shm2.size, size)
        self.assertEqual(shm2.name, self.test_name)

        # Both should reference the same shared memory
        self.assertEqual(shm1.name, shm2.name)

        # Clean up
        del shm1, shm2

    def test_rebuild_shared_memory_default_dtype(self):
        """Test rebuilding tensor from shared memory with default dtype (uint8)."""
        size = 1024

        # Create and write to shared memory
        shm = create_shared_memory(size, self.test_name)
        test_data = torch.arange(size, dtype=torch.uint8)
        shm.buf[:size] = test_data.numpy().tobytes()

        # Rebuild tensor from shared memory
        tensor, _ = rebuild_shared_memory(self.test_name, size)

        # Verify tensor properties
        self.assertEqual(tensor.dtype, torch.uint8)
        self.assertEqual(len(tensor), size)

        # Verify data integrity
        reconstructed = torch.frombuffer(shm.buf[:size], dtype=torch.uint8)
        self.assertTrue(torch.equal(tensor, reconstructed))

        # Clean up - delete references before closing
        del tensor, reconstructed

    def test_rebuild_shared_memory_custom_dtype(self):
        """Test rebuilding tensor from shared memory with custom dtype."""
        size = 256  # 256 bytes = 64 float32 values

        # Create and write to shared memory
        shm = create_shared_memory(size, self.test_name)
        test_data = torch.arange(64, dtype=torch.float32)
        shm.buf[:size] = test_data.numpy().tobytes()

        # Rebuild tensor with custom dtype
        tensor, _ = rebuild_shared_memory(self.test_name, size, dtype=torch.float32)

        # Verify tensor properties
        self.assertEqual(tensor.dtype, torch.float32)
        self.assertEqual(len(tensor), 64)

        # Verify data integrity
        reconstructed = torch.frombuffer(shm.buf[:size], dtype=torch.float32)
        self.assertTrue(torch.equal(tensor, reconstructed))

        # Clean up - delete references before closing
        del tensor, reconstructed

    def test_shared_memory_data_integrity(self):
        """Test that data remains intact between create and rebuild operations."""
        size = 512

        # Create test data with various patterns
        test_data = torch.randint(0, 256, (size,), dtype=torch.uint8)

        # Create shared memory and write data
        shm = create_shared_memory(size, self.test_name)
        shm.buf[:size] = test_data.numpy().tobytes()

        # Rebuild tensor
        tensor, _ = rebuild_shared_memory(self.test_name, size)

        # Verify data integrity
        reconstructed = torch.frombuffer(shm.buf[:size], dtype=torch.uint8)
        self.assertTrue(torch.equal(test_data, reconstructed))

        # Clean up - delete references before closing
        del tensor, reconstructed

    def test_shared_memory_different_dtypes(self):
        """Test shared memory operations with different tensor dtypes."""
        test_cases = [
            (torch.float32, 256, 64),  # 256 bytes / 4 bytes = 64 values
            (torch.float64, 256, 32),  # 256 bytes / 8 bytes = 32 values
            (torch.int32, 256, 64),  # 256 bytes / 4 bytes = 64 values
            (torch.int64, 256, 32),  # 256 bytes / 8 bytes = 32 values
            (torch.uint8, 256, 256),  # 256 bytes / 1 byte = 256 values
        ]

        for dtype, size, expected_len in test_cases:
            # Create test data
            test_data = torch.arange(expected_len, dtype=dtype)

            # Create shared memory and write data (all cases share one name and
            # the same byte size, so attach-on-exists re-uses the segment)
            shm = create_shared_memory(size, self.test_name)
            shm.buf[:size] = test_data.numpy().tobytes()

            # Rebuild tensor
            tensor, _ = rebuild_shared_memory(self.test_name, size, dtype=dtype)

            # Verify properties and data
            self.assertEqual(tensor.dtype, dtype)
            self.assertEqual(len(tensor), expected_len)

            reconstructed = torch.frombuffer(shm.buf[:size], dtype=dtype)
            self.assertTrue(torch.equal(test_data, reconstructed))

            # Clean up - delete references before closing
            del tensor, reconstructed

    def test_shared_memory_multiple_operations(self):
        """Test multiple create/rebuild operations with the same name."""
        size = 512

        # First iteration
        test_data1 = torch.arange(size, dtype=torch.uint8)
        shm1 = create_shared_memory(size, self.test_name)
        shm1.buf[:size] = test_data1.numpy().tobytes()
        tensor1, _ = rebuild_shared_memory(self.test_name, size)
        reconstructed1 = torch.frombuffer(shm1.buf[:size], dtype=torch.uint8)
        self.assertTrue(torch.equal(test_data1, reconstructed1))
        del tensor1, reconstructed1, shm1

        # Second iteration with different data
        test_data2 = torch.arange(size, dtype=torch.uint8) * 2
        shm2 = create_shared_memory(size, self.test_name)
        shm2.buf[:size] = test_data2.numpy().tobytes()
        tensor2, _ = rebuild_shared_memory(self.test_name, size)
        reconstructed2 = torch.frombuffer(shm2.buf[:size], dtype=torch.uint8)
        self.assertTrue(torch.equal(test_data2, reconstructed2))
        del tensor2, reconstructed2, shm2
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
# Module-level function for cross-process testing
|
| 200 |
+
def child_process_function(name, size, test_data_bytes):
    """Child process function to rebuild and verify a tensor from shared memory.

    Returns True on success. On failure the exception is re-raised so that,
    when run as a multiprocessing target, the child exits with a non-zero
    code the parent can detect via ``Process.exitcode``.
    """
    shm = None
    tensor = None
    test_data = None
    try:
        # Convert bytes back to tensor
        test_data = torch.frombuffer(test_data_bytes, dtype=torch.uint8)

        # Attach to shared memory
        shm = shared_memory.SharedMemory(name=name)

        # Rebuild tensor from shared memory
        tensor = torch.frombuffer(shm.buf[:size], dtype=torch.uint8)

        # Verify data integrity
        assert torch.equal(test_data, tensor), "Data mismatch in child process"
        return True
    except Exception as e:
        print(f"Error in child process: {e}")
        # BUGFIX: returning False from a multiprocessing target still exits
        # with code 0, so the parent's `exitcode == 0` check never noticed
        # failures. Re-raise so the child process exits non-zero.
        raise
    finally:
        # Clean up shared memory in child process.
        # Delete buffer-backed references first so close() succeeds.
        del tensor, test_data
        if shm is not None:
            shm.close()
        # Note: Don't unlink in child process, parent will clean up
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
class TestSharedMemoryIntegration(unittest.TestCase):
    """Integration tests for shared memory operations across process boundaries."""

    def test_cross_process_shared_memory(self):
        """Test shared memory can be created in one process and accessed in another."""
        size = 1024
        test_data = torch.arange(size, dtype=torch.uint8)

        # Create shared memory in parent process
        # NOTE(review): assumes create_shared_memory returns a
        # multiprocessing.shared_memory.SharedMemory-like object (it exposes
        # .buf here) — confirm that close()/unlink() are available.
        shm = create_shared_memory(size, "test_cross_proc")
        try:
            shm.buf[:size] = test_data.numpy().tobytes()

            # Convert tensor to bytes for passing to child process
            test_data_bytes = test_data.numpy().tobytes()

            # Start child process
            process = multiprocessing.Process(
                target=child_process_function, args=("test_cross_proc", size, test_data_bytes)
            )
            process.start()
            process.join(timeout=5)

            # BUGFIX: if the child is still alive after the timeout, exitcode
            # is None; terminate it so the test fails cleanly instead of
            # leaking a live child process.
            if process.is_alive():
                process.terminate()
                process.join()
                self.fail("Child process timed out")

            # Verify child process completed successfully
            self.assertEqual(process.exitcode, 0, "Child process failed")
        finally:
            # BUGFIX: POSIX shared memory persists until unlinked, so
            # `del shm` alone leaked the "test_cross_proc" segment and a
            # stale one would be silently re-attached on the next run.
            shm.close()
            try:
                shm.unlink()
            except FileNotFoundError:
                pass
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
# Allow running this test module directly (outside pytest) via
# `python test_shared_memory.py`.
if __name__ == "__main__":
    unittest.main()
|
code/RL_model/verl/verl_train/tests/utils/test_special_linear_cross_entropy_tp.py
ADDED
|
@@ -0,0 +1,514 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
#
|
| 17 |
+
|
| 18 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 19 |
+
#
|
| 20 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 21 |
+
# you may not use this file except in compliance with the License.
|
| 22 |
+
# You may obtain a copy of the License at
|
| 23 |
+
#
|
| 24 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 25 |
+
#
|
| 26 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 27 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 28 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 29 |
+
# See the License for the specific language governing permissions and
|
| 30 |
+
# limitations under the License.
|
| 31 |
+
|
| 32 |
+
import os
|
| 33 |
+
|
| 34 |
+
import torch
|
| 35 |
+
import torch.distributed as dist
|
| 36 |
+
|
| 37 |
+
try:
|
| 38 |
+
from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy
|
| 39 |
+
except ImportError:
|
| 40 |
+
# FIXME: remove these manually included paths
|
| 41 |
+
import sys
|
| 42 |
+
|
| 43 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../")))
|
| 44 |
+
finally:
|
| 45 |
+
from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy
|
| 46 |
+
|
| 47 |
+
import verl.utils.torch_functional as verl_F
|
| 48 |
+
|
| 49 |
+
# Compiled wrapper around the reference entropy computation (dynamic shapes).
compute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True)

# BUGFIX: os.environ.get returns a *string* whenever the variable is set, so
# the previous bare defaults gave inconsistent types — e.g. exporting
# VERIFY_TORCH_SELF=False produced the truthy string "False". Parse explicitly.
MAX_TEST_CASES = int(os.environ.get("MAX_TEST_CASES", 5))
VERIFY_TORCH_SELF = os.environ.get("VERIFY_TORCH_SELF", "False").lower() in ("1", "true", "yes")
LOW_MEMORY = os.environ.get("LOW_MEMORY", "False").lower() in ("1", "true", "yes")
LOW_MEMORY_DIV_FACTOR = int(os.environ.get("LOW_MEMORY_DIV_FACTOR", 16))
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def run_torch_entropy(
    hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, temperature: float, reduction="none"
) -> list[torch.Tensor]:
    """Reference (pure-torch) computation of per-token log-probs and entropy.

    `weight` may be laid out [hidden, vocab] or [vocab, hidden]; it is
    transposed as needed so the projection yields [num_tokens, vocab_size].
    """
    if len(hidden.shape) > 2:
        hidden = hidden.view(-1, hidden.shape[-1])  # flatten to [num_tokens, hidden_size]
    if len(labels.shape) > 1:
        labels = labels.view(-1)

    hidden_f32 = hidden.to(torch.float32)
    if weight.size(0) == hidden.size(1):
        projection = weight.to(torch.float32)
    else:
        projection = weight.T.to(torch.float32)
    logits = torch.matmul(hidden_f32, projection)  # [num_tokens, vocab_size]
    logits /= temperature

    # entropy = logsumexp(logits) - sum(softmax(logits) * logits)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    entropy = torch.logsumexp(logits, dim=-1) - torch.sum(probs * logits, dim=-1)

    # negative cross-entropy == log-probability of the target labels
    logprobs = -torch.nn.functional.cross_entropy(logits, labels, reduction=reduction)
    return logprobs, entropy
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class TorchEntropyTP(torch.autograd.Function):
|
| 80 |
+
"""
|
| 81 |
+
it is used for testing the correctness of the kernel
|
| 82 |
+
it is not efficient and is not recommended to use in practice
|
| 83 |
+
"""
|
| 84 |
+
|
| 85 |
+
@staticmethod
|
| 86 |
+
def forward(
|
| 87 |
+
ctx,
|
| 88 |
+
hidden: torch.Tensor,
|
| 89 |
+
weight: torch.Tensor,
|
| 90 |
+
labels: torch.Tensor,
|
| 91 |
+
temperature: float,
|
| 92 |
+
dist_process_group: torch.distributed.ProcessGroup,
|
| 93 |
+
):
|
| 94 |
+
# weight has shape [vocab_size, hidden_size], hidden has shape [num_tokens, hidden_size]
|
| 95 |
+
ctx.original_hidden_shape = hidden.shape
|
| 96 |
+
if len(hidden.shape) > 2:
|
| 97 |
+
hidden = hidden.view(-1, hidden.shape[-1]) # [num_tokens, hidden_size]
|
| 98 |
+
if len(labels.shape) > 1:
|
| 99 |
+
labels = labels.view(-1)
|
| 100 |
+
|
| 101 |
+
logits = torch.matmul(hidden.to(torch.float32), weight.to(torch.float32).T) # [num_tokens, vocab_size]
|
| 102 |
+
logits /= temperature
|
| 103 |
+
whole_logits = torch.empty(
|
| 104 |
+
(logits.shape[0], logits.shape[1] * dist.get_world_size(dist_process_group)),
|
| 105 |
+
dtype=logits.dtype,
|
| 106 |
+
device=logits.device,
|
| 107 |
+
)
|
| 108 |
+
whole_logits_ref = [
|
| 109 |
+
whole_logits[:, i * logits.shape[1] : (i + 1) * logits.shape[1]]
|
| 110 |
+
for i in range(dist.get_world_size(dist_process_group))
|
| 111 |
+
]
|
| 112 |
+
dist.all_gather(whole_logits_ref, logits, group=dist_process_group)
|
| 113 |
+
|
| 114 |
+
pd = torch.nn.functional.softmax(whole_logits, dim=-1)
|
| 115 |
+
entropy_a = torch.logsumexp(whole_logits, dim=-1) # [num_tokens]
|
| 116 |
+
entropy_b = torch.sum(pd * whole_logits, dim=-1) # [num_tokens]
|
| 117 |
+
entropy = entropy_a - entropy_b
|
| 118 |
+
|
| 119 |
+
logprobs = torch.nn.functional.cross_entropy(whole_logits, labels, reduction="none")
|
| 120 |
+
logprobs = torch.neg(logprobs)
|
| 121 |
+
|
| 122 |
+
ctx.save_for_backward(hidden, weight, labels, whole_logits, entropy_b)
|
| 123 |
+
ctx.dist_process_group = dist_process_group
|
| 124 |
+
ctx.temperature = temperature
|
| 125 |
+
return logprobs, entropy
|
| 126 |
+
|
| 127 |
+
@staticmethod
def backward(ctx, g_logprobs: torch.Tensor, g_entropy: torch.Tensor):
    """Backward of the TP forward above.

    Recomputes the full-vocab softmax from the saved gathered logits, forms
    d_loss/d_logits analytically, slices out this rank's vocab columns, and
    maps them back to gradients w.r.t. the local ``hidden`` and ``weight``.

    NOTE(review): the caller is expected to all-reduce ``d_hidden`` across the
    process group afterwards (see the harness below) — it is not done here.
    """
    hidden, weight, labels, whole_logits, entropy_b = ctx.saved_tensors
    dist_process_group = ctx.dist_process_group
    temperature = ctx.temperature
    batch_size, hidden_size = hidden.shape
    vocab_size, hidden_size = weight.shape  # per-rank shard size
    rank = dist.get_rank(dist_process_group)

    # Compute softmax probabilities (max-subtracted for stability)
    maximum, _ = torch.max(whole_logits, dim=-1, keepdim=True)
    exp_logits = torch.exp(whole_logits - maximum)
    accumulate = exp_logits.sum(dim=-1, keepdim=True)
    pd = exp_logits / accumulate

    # Gradient for entropy
    # entropy = entropy_a - entropy_b
    # entropy_a = log(sum(exp(logits)))
    # entropy_b = sum(pd * logits)
    # d_entropy_a/d_logits = pd
    # d_entropy_b/d_logits = pd * (logits - b.unsqueeze(1) + 1)
    # d_entropy/d_logits = d_entropy_a - d_entropy_b
    # d_entropy/d_logits = pd - pd * (logits - b.unsqueeze(1) + 1)
    # d_entropy/d_logits = -pd * (logits - b.unsqueeze(1))
    d_logits_entropy = g_entropy.unsqueeze(1) * (-pd * (whole_logits - entropy_b.unsqueeze(1)))

    # Gradient for logprobs
    # logprobs = -cross_entropy = -log(pd[labels])
    # d_logprobs/d_logits = (pd - one_hot(labels))
    one_hot = torch.zeros_like(whole_logits)
    one_hot.scatter_(1, labels.unsqueeze(1), 1)
    g_logprobs = torch.neg(g_logprobs)
    d_logits_logprobs = g_logprobs.unsqueeze(1) * (pd - one_hot)
    # NOTE: This will lead to wrong result
    # d_logits_logprobs = g_logprobs.unsqueeze(1) * (pd - 1) * one_hot

    # Combine gradients; the 1/temperature factor comes from logits /= temperature
    # in forward (chain rule).
    d_logits = d_logits_entropy + d_logits_logprobs
    d_logits /= temperature

    # Get local slice of gradients (this rank's vocab columns)
    local_d_logits = d_logits[:, rank * vocab_size : (rank + 1) * vocab_size]

    # Compute gradients for hidden and weight (fp32, matching forward's matmul dtype)
    d_hidden = torch.matmul(local_d_logits, weight.to(torch.float32))
    d_weight = torch.matmul(local_d_logits.T, hidden.to(torch.float32))
    d_hidden = d_hidden.view(ctx.original_hidden_shape)

    # One gradient per forward input: (hidden, weight, labels, temperature, group).
    return d_hidden, d_weight, None, None, None
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
# Functional entry point for the tensor-parallel torch reference implementation.
run_torch_entropy_tp = TorchEntropyTP.apply
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
class TestLinearCrossEntropy_TensorParallel:
    """Tensor-parallel test harness for linear-cross-entropy implementations.

    Meant to run under torchrun with WORLD_SIZE > 1 (see ``__main__`` below);
    every rank owns a vocab shard of the classifier weight.  It compares:

      * ``run_torch_entropy``     — single-device torch reference (full vocab),
      * ``run_torch_entropy_tp``  — tensor-parallel torch reference (this file),
      * ``linear_cross_entropy``  — the fused kernel under test,

    for numerical correctness, latency, and peak CUDA memory.
    """

    def __init__(self):
        # One process per GPU; the WORLD group is reused for every collective below.
        dist.init_process_group(backend="nccl")
        self.group = dist.group.WORLD

        self.local_rank = dist.get_rank(self.group)
        self.world_size = dist.get_world_size(self.group)
        device = torch.device(f"cuda:{self.local_rank}")
        torch.cuda.set_device(device)
        print(f"[INFO]: Local rank: {self.local_rank}, World size: {self.world_size}")

    def initialize(self, test_case_idx: int, temperature: float = 1.5):
        # Select which hyper-parameter set generate_hyper() will materialize.
        self.test_case_idx = test_case_idx
        self.temperature = temperature

    def shutdown(self):
        dist.destroy_process_group()

    def cleanup(self):
        # Drop cached CUDA memory and reset stats so peak-memory numbers
        # reflect only the phase being measured.
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        import gc

        gc.collect()
        torch.cuda.synchronize()

    def generate_hyper(self):
        """Materialize (batch, tokens, hidden, vocab) for the selected case.

        ``vocab_size`` here is the PER-RANK shard size; the logical vocabulary
        seen by the gathered logits is ``vocab_size * world_size``.
        LOW_MEMORY / LOW_MEMORY_DIV_FACTOR / MAX_TEST_CASES are module-level
        knobs defined above this chunk.
        """
        global LOW_MEMORY, LOW_MEMORY_DIV_FACTOR, MAX_TEST_CASES

        self.dtype = torch.bfloat16
        if self.test_case_idx == 0:
            self.batch_size = 1
            self.num_tokens = 1937
            self.hidden_size = 3584
            self.vocab_size = 152064
        elif self.test_case_idx == 1:
            self.batch_size = 1
            self.num_tokens = 2169
            self.hidden_size = 896
            self.vocab_size = 151936
        elif self.test_case_idx == 2:
            self.batch_size = 1
            self.num_tokens = 1530
            self.hidden_size = 2048
            self.vocab_size = 32256
        elif self.test_case_idx == 3:
            self.batch_size = 1
            self.num_tokens = 1388
            self.hidden_size = 4096
            self.vocab_size = 102400
        elif self.test_case_idx == 4:
            self.batch_size = 1
            self.num_tokens = 8192
            self.hidden_size = 4096
            self.vocab_size = 102400
        else:
            raise ValueError(f"Invalid test case index: {self.test_case_idx}")
        if LOW_MEMORY:
            self.vocab_size = int(self.vocab_size / LOW_MEMORY_DIV_FACTOR)
        assert MAX_TEST_CASES <= 5, "MAX_TEST_CASES should be less than or equal to 5."

    def generate_forward_inputs(self):
        # Fresh random activations/weights per iteration; leaf tensors so
        # torch.autograd.grad can differentiate w.r.t. them.
        # NOTE(review): labels are drawn from [0, per-rank vocab_size), i.e.
        # only from the first shard of the logical vocabulary — confirm intended.
        hidden = (
            torch.empty((self.batch_size, self.num_tokens, self.hidden_size), dtype=self.dtype, device="cuda")
            .uniform_(-0.5, 0.5)
            .requires_grad_()
        )
        weight = (
            torch.empty((self.vocab_size, self.hidden_size), dtype=self.dtype, device="cuda")
            .uniform_(-0.5, 0.5)
            .requires_grad_()
        )
        labels = torch.randint(0, self.vocab_size, (self.batch_size, self.num_tokens), device="cuda")
        return hidden, weight, labels

    def generate_backward_inputs(self):
        # Upstream gradients for (entropy, logprobs), one scalar per token.
        g_entropy = torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-0.5, 0.5)
        g_logprobs = torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-1, 1)
        return g_entropy, g_logprobs

    def verify_torch_itself(self, iterations: int = 5):
        """Check the TP torch reference against the single-device torch reference."""
        self.cleanup()
        self.generate_hyper()

        for i in range(iterations):
            hidden, weight, labels = self.generate_forward_inputs()

            # NOTE: we need to manually synchronize hidden and labels among Process Group
            dist.broadcast(hidden, src=0, group=self.group)
            dist.broadcast(labels, src=0, group=self.group)

            # forward pass
            # Create a tensor to hold the gathered weights from all ranks
            # weight has shape [vocab_size, hidden_size]
            # We want to gather along the first dimension to get [vocab_size * world_size, hidden_size]

            # Create a single contiguous tensor to hold all gathered weights
            whole_weight = torch.empty(
                (self.vocab_size * self.world_size, self.hidden_size), dtype=weight.dtype, device=weight.device
            )

            # Create views into the tensor for each rank's portion
            whole_weight_views = [
                whole_weight[i * self.vocab_size : (i + 1) * self.vocab_size] for i in range(self.world_size)
            ]

            # Perform all_gather operation using the views
            dist.all_gather(whole_weight_views, weight, group=self.group)

            # Set requires_grad for autograd
            whole_weight.requires_grad_()

            (single_logprobs, single_entropy) = run_torch_entropy(hidden, whole_weight, labels, self.temperature)

            (tp_logprobs, tp_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.temperature, self.group)

            torch.testing.assert_close(single_logprobs, tp_logprobs, atol=1e-4, rtol=1e-4)
            torch.testing.assert_close(single_entropy, tp_entropy, atol=1e-4, rtol=1e-4)

            # backward pass
            g_entropy, g_logprobs = self.generate_backward_inputs()
            # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group
            dist.broadcast(g_entropy, src=0, group=self.group)
            dist.broadcast(g_logprobs, src=0, group=self.group)

            (single_d_hidden, single_d_weight) = torch.autograd.grad(
                (single_entropy, single_logprobs), (hidden, whole_weight), (g_entropy, g_logprobs), retain_graph=False
            )

            (tp_d_hidden, tp_d_weight) = torch.autograd.grad(
                (tp_entropy, tp_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False
            )
            # NOTE: all-reduce on hidden is conducted outside the kernel
            dist.all_reduce(tp_d_hidden, op=dist.ReduceOp.SUM, group=self.group)

            torch.testing.assert_close(tp_d_hidden, single_d_hidden, atol=1e-2, rtol=1e-4)
            # Extract the corresponding slice from single_d_weight for comparison
            # tp_d_weight has shape [vocab_size, hidden_size]
            # single_d_weight has shape [vocab_size * world_size, hidden_size]
            torch.testing.assert_close(
                tp_d_weight,
                single_d_weight[self.local_rank * self.vocab_size : (self.local_rank + 1) * self.vocab_size],
                atol=1e-2,
                rtol=1e-4,
            )

        if self.local_rank == 0:
            print("[PASS] torch TP correctness is verified")

    def check_torch_storage(self):
        """Report peak CUDA memory of the TP torch reference (forward & backward)."""
        self.cleanup()
        self.generate_hyper()

        hidden, weight, labels = self.generate_forward_inputs()

        # NOTE: we need to manually synchronize hidden and labels among Process Group
        dist.broadcast(hidden, src=0, group=self.group)
        dist.broadcast(labels, src=0, group=self.group)

        torch.cuda.reset_peak_memory_stats()
        (tp_logprobs, tp_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.temperature, self.group)
        torch.cuda.synchronize()
        forward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024

        g_entropy, g_logprobs = self.generate_backward_inputs()
        # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group
        dist.broadcast(g_entropy, src=0, group=self.group)
        dist.broadcast(g_logprobs, src=0, group=self.group)

        torch.cuda.reset_peak_memory_stats()
        (d_tp_hidden, d_tp_weight) = torch.autograd.grad(
            (tp_entropy, tp_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False
        )
        torch.cuda.synchronize()
        backward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024
        # NOTE: all-reduce on hidden is conducted outside the kernel
        dist.all_reduce(d_tp_hidden, op=dist.ReduceOp.SUM, group=self.group)

        if self.local_rank == 0:
            print(f"[INFO]: Torch Forward pass peak memory: {forward_max_memory:.2f} MB")
            print(f"[INFO]: Torch Backward pass peak memory: {backward_max_memory:.2f} MB")

    def verify_kernel_correctness(self, iterations: int = 5):
        """Compare the fused kernel against the TP torch reference and time both."""
        self.cleanup()
        self.generate_hyper()

        torch_forward_latency = list()
        torch_backward_latency = list()
        kernel_forward_latency = list()
        kernel_backward_latency = list()

        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

        for i in range(iterations):
            hidden, weight, labels = self.generate_forward_inputs()

            # NOTE: we need to manually synchronize hidden and labels among Process Group
            dist.broadcast(hidden, src=0, group=self.group)
            dist.broadcast(labels, src=0, group=self.group)

            start_event.record()
            (torch_logprobs, torch_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.temperature, self.group)
            end_event.record()
            torch.cuda.synchronize()
            torch_forward_latency.append(start_event.elapsed_time(end_event))

            start_event.record()
            (kernel_logprobs, kernel_entropy) = linear_cross_entropy(
                hidden, weight, labels, self.temperature, "none", self.group
            )
            end_event.record()
            torch.cuda.synchronize()
            kernel_forward_latency.append(start_event.elapsed_time(end_event))

            torch.testing.assert_close(torch_logprobs, kernel_logprobs, atol=1e-1, rtol=1e-2)
            torch.testing.assert_close(torch_entropy, kernel_entropy, atol=1e-1, rtol=1e-2)

            # backward pass
            g_entropy, g_logprobs = self.generate_backward_inputs()
            # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group
            dist.broadcast(g_entropy, src=0, group=self.group)
            dist.broadcast(g_logprobs, src=0, group=self.group)

            start_event.record()
            (torch_d_hidden, torch_d_weight) = torch.autograd.grad(
                (torch_entropy, torch_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False
            )
            end_event.record()
            torch.cuda.synchronize()
            torch_backward_latency.append(start_event.elapsed_time(end_event))
            # NOTE: all-reduce on hidden is conducted outside the kernel
            dist.all_reduce(torch_d_hidden, op=dist.ReduceOp.SUM, group=self.group)

            start_event.record()
            (kernel_d_hidden, kernel_d_weight) = torch.autograd.grad(
                (kernel_entropy, kernel_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False
            )
            end_event.record()
            torch.cuda.synchronize()
            kernel_backward_latency.append(start_event.elapsed_time(end_event))
            # NOTE: all-reduce on hidden is conducted outside the kernel
            dist.all_reduce(kernel_d_hidden, op=dist.ReduceOp.SUM, group=self.group)

            torch.testing.assert_close(torch_d_hidden, kernel_d_hidden, atol=2e-2, rtol=4e-2)
            torch.testing.assert_close(torch_d_weight, kernel_d_weight, atol=2e-2, rtol=4e-2)

        # remove first latency (warm-up iteration)
        torch_forward_latency = torch_forward_latency[1:]
        torch_backward_latency = torch_backward_latency[1:]
        kernel_forward_latency = kernel_forward_latency[1:]
        kernel_backward_latency = kernel_backward_latency[1:]

        if self.local_rank == 0:
            print("\n[PASS]: Verified kernel forward & backward correctness.")

            print(
                f"[INFO]: Forward pass: Torch implementation average time: "
                f"{sum(torch_forward_latency) / len(torch_forward_latency):.2f} ms"
            )
            print(
                f"[INFO]: Backward pass: torch implementation average time: "
                f"{sum(torch_backward_latency) / len(torch_backward_latency):.2f} ms"
            )
            print(
                f"[INFO]: Forward pass: Kernel implementation average time: "
                f"{sum(kernel_forward_latency) / len(kernel_forward_latency):.2f} ms"
            )
            print(
                f"[INFO]: Backward pass: kernel implementation average time: "
                f"{sum(kernel_backward_latency) / len(kernel_backward_latency):.2f} ms"
            )

    def check_kernel_storage(self):
        """Report peak CUDA memory of the fused kernel (forward & backward)."""
        self.cleanup()
        self.generate_hyper()

        hidden, weight, labels = self.generate_forward_inputs()

        # NOTE: we need to manually synchronize hidden and labels among Process Group
        dist.broadcast(hidden, src=0, group=self.group)
        dist.broadcast(labels, src=0, group=self.group)

        torch.cuda.reset_peak_memory_stats()
        (kernel_logprobs, kernel_entropy) = linear_cross_entropy(
            hidden, weight, labels, self.temperature, "none", self.group
        )
        torch.cuda.synchronize()
        kernel_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024

        g_entropy, g_logprobs = self.generate_backward_inputs()
        # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group
        dist.broadcast(g_entropy, src=0, group=self.group)
        dist.broadcast(g_logprobs, src=0, group=self.group)

        torch.cuda.reset_peak_memory_stats()
        (d_kernel_hidden, d_kernel_weight) = torch.autograd.grad(
            (kernel_entropy, kernel_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False
        )
        torch.cuda.synchronize()
        kernel_backward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024
        # NOTE: all-reduce on hidden is conducted outside the kernel
        dist.all_reduce(d_kernel_hidden, op=dist.ReduceOp.SUM, group=self.group)

        if self.local_rank == 0:
            print(f"[INFO]: Kernel Forward pass peak memory: {kernel_max_memory:.2f} MB")
            print(f"[INFO]: Kernel Backward pass peak memory: {kernel_backward_max_memory:.2f} MB")
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
if __name__ == "__main__":
    # TP command: torchrun --standalone --nnodes=1 --nproc-per-node=2 tests/kernels/test_linear_cross_entropy_tp.py

    # Check if running with torchrun (distributed mode)
    assert int(os.environ["WORLD_SIZE"]) > 1, (
        "[ERROR]: This test is designed to run in distributed mode with torchrun. Please use torchrun to "
        "execute this script."
    )
    # Per-rank seed so each rank draws distinct random weights (inputs that must
    # match across ranks are broadcast from rank 0 inside the harness).
    torch.manual_seed(233376 + int(os.environ.get("RANK", 0)))

    # set_backward_method(BackwardEnum._Total_Fuse_MN)
    # set_backward_method(BackwardEnum._Split_Dlogits_N)

    test = TestLinearCrossEntropy_TensorParallel()
    for test_case_idx in range(MAX_TEST_CASES):
        print(f"[INFO] Running test case {test_case_idx}")
        test.initialize(test_case_idx)
        if VERIFY_TORCH_SELF:
            test.verify_torch_itself()
        test.check_torch_storage()
        test.verify_kernel_correctness()
        test.check_kernel_storage()

    test.shutdown()
|
code/RL_model/verl/verl_train/tests/utils/test_special_mstx_profile.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import unittest
|
| 16 |
+
from unittest.mock import MagicMock, patch
|
| 17 |
+
|
| 18 |
+
from verl.utils.profiler.config import NPUToolConfig, ProfilerConfig
|
| 19 |
+
from verl.utils.profiler.mstx_profile import NPUProfiler
|
| 20 |
+
from verl.utils.profiler.profile import DistProfiler
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class TestNPUProfilerInitialization(unittest.TestCase):
    """Construction-time behaviour of DistProfiler with the NPU (mstx) backend."""

    def setUp(self):
        # Reset the class-level counter shared by all NPUProfiler instances.
        NPUProfiler._define_count = 0

    def _build(self, rank, **config_kwargs):
        # Helper: DistProfiler with a default NPUToolConfig and the given
        # ProfilerConfig fields layered on top of tool="npu".
        cfg = ProfilerConfig(tool="npu", **config_kwargs)
        return DistProfiler(rank=rank, config=cfg, tool_config=NPUToolConfig())

    def test_init_with_default_config(self):
        # Default ProfilerConfig leaves profiling off.
        self.assertFalse(self._build(0).check_enable())

    def test_init_with_disabled_config(self):
        # Explicit enable=False behaves the same as the default.
        self.assertFalse(self._build(0, enable=False).check_enable())

    def test_init_with_all_ranks_true(self):
        # all_ranks=True selects every rank, including rank 0.
        self.assertTrue(self._build(0, enable=True, all_ranks=True).check_this_rank())

    def test_init_with_ranks_list(self):
        # A rank listed in `ranks` is selected.
        self.assertTrue(self._build(1, enable=True, ranks=[1, 2]).check_this_rank())

    def test_init_with_rank_not_in_ranks(self):
        # A rank absent from `ranks` is not selected.
        self.assertFalse(self._build(3, enable=True, ranks=[1, 2]).check_this_rank())
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class TestNPUProfilerStart(unittest.TestCase):
    """Behaviour of DistProfiler.start() for the NPU (mstx) backend."""

    def setUp(self):
        # Shared fixtures: profiling enabled on rank 0 only, non-discrete mode.
        NPUProfiler._define_count = 0
        self.config = ProfilerConfig(enable=True, ranks=[0], tool="npu")
        self.tool_config = NPUToolConfig(discrete=False)

    def _profiler(self, rank, tool_config=None):
        # Helper: build a profiler for `rank`, optionally overriding the tool config.
        chosen = self.tool_config if tool_config is None else tool_config
        return DistProfiler(rank=rank, config=self.config, tool_config=chosen)

    @patch("verl.utils.profiler.mstx_profile.get_npu_profiler")
    def test_start_when_enabled_and_this_rank(self, mock_get_profiler):
        # Enabled + selected rank: start() activates the step and builds one profiler.
        prof = self._profiler(0)
        prof.start(role="worker", profile_step="1")
        self.assertTrue(prof.check_this_step())
        self.assertEqual(NPUProfiler._define_count, 1)
        mock_get_profiler.assert_called_once()

    @patch("verl.utils.profiler.mstx_profile.get_npu_profiler")
    def test_start_when_not_this_rank(self, mock_get_profiler):
        # Rank 1 is not in ranks=[0]: start() is a no-op.
        prof = self._profiler(1)
        prof.start()
        self.assertFalse(prof.check_this_step())
        self.assertEqual(NPUProfiler._define_count, 0)
        mock_get_profiler.assert_not_called()

    @patch("verl.utils.profiler.mstx_profile.get_npu_profiler")
    def test_start_discrete_mode_does_not_increase_count(self, mock_get_profiler):
        # Discrete mode defers profiler creation, so the count stays at zero.
        prof = self._profiler(0, tool_config=NPUToolConfig(discrete=True))
        prof.start()
        self.assertEqual(NPUProfiler._define_count, 0)
        mock_get_profiler.assert_not_called()

    @patch("verl.utils.profiler.mstx_profile.get_npu_profiler")
    def test_multiple_start_calls_do_not_increase_count(self, mock_get_profiler):
        # A second start() on the same instance must not create another profiler.
        prof = self._profiler(0)
        prof.start()
        prof.start()
        self.assertEqual(NPUProfiler._define_count, 1)
        mock_get_profiler.assert_called_once()
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class TestNPUProfilerStartStopInteraction(unittest.TestCase):
    """start()/stop() pairing and the class-level _define_count bookkeeping."""

    def setUp(self):
        # _define_count is class-level state on NPUProfiler, shared by all instances.
        NPUProfiler._define_count = 0
        self.config = ProfilerConfig(enable=True, ranks=[0], tool="npu")
        self.tool_config = NPUToolConfig(discrete=False)

    @patch("verl.utils.profiler.mstx_profile.get_npu_profiler")
    def test_start_stop_cycle(self, mock_get_profiler):
        # A full cycle: start() builds and starts one profiler; stop() steps,
        # stops it, and decrements the shared counter back to zero.
        mock_profile_npu = MagicMock()
        mock_get_profiler.return_value = mock_profile_npu

        profiler = DistProfiler(rank=0, config=self.config, tool_config=self.tool_config)
        profiler.start()
        self.assertEqual(NPUProfiler._define_count, 1)
        self.assertEqual(mock_profile_npu.start.call_count, 1)
        profiler.stop()
        self.assertEqual(NPUProfiler._define_count, 0)
        self.assertEqual(mock_profile_npu.step.call_count, 1)
        self.assertEqual(mock_profile_npu.stop.call_count, 1)

    @patch("verl.utils.profiler.mstx_profile.get_npu_profiler")
    def test_multiple_instances_share_define_count(self, mock_get_profiler):
        # Two instances share the counter: the second start() must not create
        # a second underlying profiler, and one stop() releases it.
        mock_profile_npu = MagicMock()
        mock_get_profiler.return_value = mock_profile_npu

        profiler1 = DistProfiler(rank=0, config=self.config, tool_config=self.tool_config)
        profiler2 = DistProfiler(rank=0, config=self.config, tool_config=self.tool_config)
        profiler1.start()
        profiler2.start()
        self.assertEqual(NPUProfiler._define_count, 1)
        self.assertEqual(mock_profile_npu.start.call_count, 1)
        profiler1.stop()
        self.assertEqual(NPUProfiler._define_count, 0)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class TestNPUProfilerAnnotate(unittest.TestCase):
|
| 133 |
+
def setUp(self):
|
| 134 |
+
self.config = ProfilerConfig(enable=True, all_ranks=True, tool="npu")
|
| 135 |
+
self.tool_config = NPUToolConfig(discrete=False)
|
| 136 |
+
self.rank = 0
|
| 137 |
+
|
| 138 |
+
def test_annotate_decorator_applied_correctly(self):
|
| 139 |
+
mock_worker = MagicMock()
|
| 140 |
+
mock_worker.profiler = DistProfiler(rank=self.rank, config=self.config, tool_config=self.tool_config)
|
| 141 |
+
# Manually set private attribute for testing annotation in active step
|
| 142 |
+
mock_worker.profiler._this_step = True
|
| 143 |
+
|
| 144 |
+
mock_mark_range = "mocked_range_handle"
|
| 145 |
+
|
| 146 |
+
with (
|
| 147 |
+
patch("verl.utils.profiler.mstx_profile.mark_start_range") as mock_start_patch,
|
| 148 |
+
patch("verl.utils.profiler.mstx_profile.mark_end_range") as mock_end_patch,
|
| 149 |
+
):
|
| 150 |
+
mock_start_patch.return_value = mock_mark_range
|
| 151 |
+
|
| 152 |
+
with patch("verl.utils.profiler.mstx_profile.get_npu_profiler") as mock_get_profiler:
|
| 153 |
+
decorator = mock_worker.profiler.annotate(message="test")
|
| 154 |
+
|
| 155 |
+
@decorator
|
| 156 |
+
def test_func(self, *args, **kwargs):
|
| 157 |
+
return "result"
|
| 158 |
+
|
| 159 |
+
result = test_func(mock_worker)
|
| 160 |
+
|
| 161 |
+
self.assertEqual(result, "result")
|
| 162 |
+
mock_start_patch.assert_called_once_with(message="test")
|
| 163 |
+
mock_end_patch.assert_called_once_with(mock_mark_range)
|
| 164 |
+
mock_get_profiler.assert_not_called()
|
| 165 |
+
|
| 166 |
+
def test_annotate_when_profiler_disabled(self):
|
| 167 |
+
disabled_config = ProfilerConfig(enable=False, tool="npu")
|
| 168 |
+
mock_worker = MagicMock()
|
| 169 |
+
mock_worker.profiler = DistProfiler(rank=self.rank, config=disabled_config, tool_config=self.tool_config)
|
| 170 |
+
|
| 171 |
+
with (
|
| 172 |
+
patch("verl.utils.profiler.mstx_profile.mark_start_range") as mock_start_patch,
|
| 173 |
+
patch("verl.utils.profiler.mstx_profile.mark_end_range") as mock_end_patch,
|
| 174 |
+
patch("verl.utils.profiler.mstx_profile.get_npu_profiler") as mock_get_profiler,
|
| 175 |
+
):
|
| 176 |
+
decorator = mock_worker.profiler.annotate(message="test")
|
| 177 |
+
|
| 178 |
+
@decorator
|
| 179 |
+
def test_func(self, *args, **kwargs):
|
| 180 |
+
return "result"
|
| 181 |
+
|
| 182 |
+
result = test_func(mock_worker)
|
| 183 |
+
|
| 184 |
+
self.assertEqual(result, "result")
|
| 185 |
+
mock_start_patch.assert_not_called()
|
| 186 |
+
mock_end_patch.assert_not_called()
|
| 187 |
+
mock_get_profiler.assert_not_called()
|
| 188 |
+
|
| 189 |
+
def test_annotate_when_this_step_disabled(self):
|
| 190 |
+
mock_worker = MagicMock()
|
| 191 |
+
mock_worker.profiler = DistProfiler(rank=self.rank, config=self.config, tool_config=self.tool_config)
|
| 192 |
+
mock_worker.profiler._this_step = False
|
| 193 |
+
|
| 194 |
+
with (
|
| 195 |
+
patch("verl.utils.profiler.mstx_profile.mark_start_range") as mock_start_patch,
|
| 196 |
+
patch("verl.utils.profiler.mstx_profile.mark_end_range") as mock_end_patch,
|
| 197 |
+
patch("verl.utils.profiler.mstx_profile.get_npu_profiler") as mock_get_profiler,
|
| 198 |
+
):
|
| 199 |
+
decorator = mock_worker.profiler.annotate(message="test")
|
| 200 |
+
|
| 201 |
+
@decorator
|
| 202 |
+
def test_func(self, *args, **kwargs):
|
| 203 |
+
return "result"
|
| 204 |
+
|
| 205 |
+
result = test_func(mock_worker)
|
| 206 |
+
|
| 207 |
+
self.assertEqual(result, "result")
|
| 208 |
+
mock_start_patch.assert_not_called()
|
| 209 |
+
mock_end_patch.assert_not_called()
|
| 210 |
+
mock_get_profiler.assert_not_called()
|
| 211 |
+
|
| 212 |
+
def test_annotate_discrete_mode_enabled(self):
    """In discrete mode, annotate() is expected to wrap each call in mstx range
    marks and additionally drive a per-call NPU profiler session
    (start/step/stop), as asserted below."""
    discrete_tool_config = NPUToolConfig(discrete=True)
    mock_worker = MagicMock()
    mock_worker.profiler = DistProfiler(rank=self.rank, config=self.config, tool_config=discrete_tool_config)
    # Force the "profiling this step" state.
    mock_worker.profiler._this_step = True

    mock_mark_range = "mocked_range_handle"
    mock_profile_npu = MagicMock()

    with (
        patch("verl.utils.profiler.mstx_profile.mark_start_range") as mock_start_patch,
        patch("verl.utils.profiler.mstx_profile.mark_end_range") as mock_end_patch,
        patch("verl.utils.profiler.mstx_profile.get_npu_profiler") as mock_get_profiler,
    ):
        mock_start_patch.return_value = mock_mark_range
        mock_get_profiler.return_value = mock_profile_npu
        decorator = mock_worker.profiler.annotate(message="test", role="test_role")

        @decorator
        def test_func(self, *args, **kwargs):
            return "result"

        result = test_func(mock_worker)

        self.assertEqual(result, "result")
        # The range opens with the supplied message and closes with the
        # handle returned by mark_start_range.
        mock_start_patch.assert_called_once_with(message="test")
        mock_end_patch.assert_called_once_with(mock_mark_range)
        # The discrete profiler is built from the profiler's own
        # configuration plus the role passed to annotate().
        mock_get_profiler.assert_called_once_with(
            contents=mock_worker.profiler._impl.profile_contents,
            profile_level=mock_worker.profiler._impl.profile_level,
            profile_save_path=mock_worker.profiler._impl.profile_save_path,
            analysis=mock_worker.profiler._impl.analysis,
            role="test_role",
        )
        # Exactly one start/step/stop cycle per decorated call.
        mock_profile_npu.start.assert_called_once()
        mock_profile_npu.step.assert_called_once()
        mock_profile_npu.stop.assert_called_once()
|
| 249 |
+
|
| 250 |
+
def test_annotate_with_default_message(self):
    """When annotate() is called without a message, the wrapped function's
    name is used as the range message."""
    mock_worker = MagicMock()
    mock_worker.profiler = DistProfiler(rank=self.rank, config=self.config, tool_config=self.tool_config)
    mock_worker.profiler._this_step = True

    mock_mark_range = "mocked_range_handle"
    with (
        patch("verl.utils.profiler.mstx_profile.mark_start_range") as mock_start_patch,
        patch("verl.utils.profiler.mstx_profile.mark_end_range") as mock_end_patch,
    ):
        mock_start_patch.return_value = mock_mark_range
        decorator = mock_worker.profiler.annotate()

        @decorator
        def test_func(self, *args, **kwargs):
            return "result"

        test_func(mock_worker)

        # Default message falls back to the function name "test_func".
        mock_start_patch.assert_called_once_with(message="test_func")
        mock_end_patch.assert_called_once_with(mock_mark_range)
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
# Allow running this test module directly (outside pytest/unittest discovery).
if __name__ == "__main__":
    unittest.main()
|
code/RL_model/verl/verl_train/tests/utils/test_temp_env_on_cpu.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
|
| 17 |
+
import pytest
|
| 18 |
+
|
| 19 |
+
from verl.utils.py_functional import temp_env_var
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@pytest.fixture(autouse=True)
def clean_env():
    """Snapshot os.environ, drop known test variables, and restore the
    snapshot after each test so tests cannot leak state into each other."""
    saved_environ = dict(os.environ)

    # Make sure no leftover test variables are visible to the test body.
    for name in ("TEST_VAR", "TEST_VAR_2", "EXISTING_VAR"):
        os.environ.pop(name, None)

    yield

    # Put the environment back exactly as it was before the test ran.
    os.environ.clear()
    os.environ.update(saved_environ)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def test_set_new_env_var():
    """A variable that did not exist is set inside the context and removed on exit."""
    # Precondition: the variable must be absent.
    assert "TEST_VAR" not in os.environ

    with temp_env_var("TEST_VAR", "test_value"):
        # Inside the context the variable is visible with the requested value.
        assert "TEST_VAR" in os.environ
        assert os.environ["TEST_VAR"] == "test_value"

    # Leaving the context removes the variable again.
    assert "TEST_VAR" not in os.environ
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def test_restore_existing_env_var():
    """An already-set variable is restored to its original value on exit."""
    original = "original_value"
    os.environ["EXISTING_VAR"] = original

    with temp_env_var("EXISTING_VAR", "temporary_value"):
        # The override is in effect only inside the context.
        assert os.environ.get("EXISTING_VAR") == "temporary_value"

    # On exit the pre-existing value is back.
    assert os.environ.get("EXISTING_VAR") == original
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def test_env_var_restored_on_exception():
    """Restoration must happen even when the context body raises."""
    os.environ["EXISTING_VAR"] = "original_value"

    with pytest.raises(ValueError):
        with temp_env_var("EXISTING_VAR", "temporary_value"):
            assert os.environ.get("EXISTING_VAR") == "temporary_value"
            raise ValueError("Test exception")

    # The exception must not prevent the original value from coming back.
    assert os.environ.get("EXISTING_VAR") == "original_value"
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def test_nested_context_managers():
    """Nested temp_env_var contexts unwind one level at a time."""
    os.environ["TEST_VAR"] = "original"

    with temp_env_var("TEST_VAR", "level1"):
        assert os.environ.get("TEST_VAR") == "level1"

        with temp_env_var("TEST_VAR", "level2"):
            assert os.environ.get("TEST_VAR") == "level2"

        # Inner exit restores the outer override, not the original.
        assert os.environ.get("TEST_VAR") == "level1"

    # Outer exit restores the pre-context value.
    assert os.environ.get("TEST_VAR") == "original"
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def test_multiple_different_vars():
    """Override one existing and one brand-new variable at the same time."""
    os.environ["EXISTING_VAR"] = "existing_value"

    # Equivalent to the nested form: the second context exits first.
    with temp_env_var("EXISTING_VAR", "modified"), temp_env_var("TEST_VAR", "new_value"):
        assert os.environ.get("EXISTING_VAR") == "modified"
        assert os.environ.get("TEST_VAR") == "new_value"

    # The existing variable is restored; the new one disappears entirely.
    assert os.environ.get("EXISTING_VAR") == "existing_value"
    assert "TEST_VAR" not in os.environ
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def test_empty_string_value():
    """An empty string is a legitimate value, distinct from 'unset'."""
    with temp_env_var("TEST_VAR", ""):
        assert "TEST_VAR" in os.environ
        assert os.environ["TEST_VAR"] == ""

    # The variable was absent before, so it must be absent again.
    assert "TEST_VAR" not in os.environ
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def test_overwrite_with_empty_string():
    """Temporarily blanking an existing variable still restores its old value."""
    os.environ["EXISTING_VAR"] = "original"

    with temp_env_var("EXISTING_VAR", ""):
        assert os.environ.get("EXISTING_VAR") == ""

    assert os.environ.get("EXISTING_VAR") == "original"
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def test_context_manager_returns_none():
    """The context manager yields None while still applying the variable."""
    with temp_env_var("TEST_VAR", "value") as yielded:
        assert yielded is None
        assert os.environ.get("TEST_VAR") == "value"
|
code/RL_model/verl/verl_train/tests/utils/test_timeout_decorator_cpu.py
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import multiprocessing
|
| 16 |
+
import sys
|
| 17 |
+
import threading
|
| 18 |
+
import time
|
| 19 |
+
|
| 20 |
+
import pytest # Import pytest
|
| 21 |
+
|
| 22 |
+
from verl.utils.py_functional import timeout_limit as timeout
|
| 23 |
+
|
| 24 |
+
# --- Test Task Functions ---
# Shared timing knobs: every decorated task below is given a
# TEST_TIMEOUT_SECONDS budget, and LONG_TASK_DURATION deliberately
# exceeds it so slow tasks are guaranteed to trip the timeout.
TEST_TIMEOUT_SECONDS = 1.5  # Timeout duration for tests
LONG_TASK_DURATION = TEST_TIMEOUT_SECONDS + 0.5  # Duration slightly longer than timeout
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@timeout(seconds=TEST_TIMEOUT_SECONDS)  # Keep global decorator for mp tests
def quick_task(x):
    """A task that finishes well inside the timeout budget."""
    brief_nap = 0.1
    time.sleep(brief_nap)
    return "quick_ok"
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@timeout(seconds=TEST_TIMEOUT_SECONDS)  # Keep global decorator for mp tests
def slow_task(x):
    """A task that deliberately outlives the timeout budget."""
    time.sleep(LONG_TASK_DURATION)
    # Reaching this return means the timeout did NOT fire.
    return "slow_finished"
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# REMOVE global decorator here
def task_raises_value_error():  # Now truly not globally decorated
    """A task whose only job is to raise a ValueError, for propagation tests."""
    raise ValueError("Specific value error from task")
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# --- Top-level function for signal test in subprocess ---
# Keep this decorated globally for the specific subprocess test case
@timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=True)
def top_level_decorated_quick_task_signal():
    """A pickleable top-level function decorated with signal timeout."""
    # Mirrors quick_task's logic; the distinct return value makes results
    # coming back from a subprocess unambiguous.
    brief_nap = 0.1
    time.sleep(brief_nap)
    return "quick_ok_signal_subprocess"
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# --- Top-level function for signal test in subprocess ---
# Keep this decorated globally for the specific subprocess test case
@timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=True)
def top_level_decorated_slow_task_signal():
    """A pickleable top-level slow task decorated with signal timeout."""
    time.sleep(LONG_TASK_DURATION)
    # Reaching this return means the timeout did NOT fire.
    return "slow_finished"
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
# --- NEW: Top-level helper function to run target in process ---
def run_target_and_put_in_queue(target_func, q):
    """
    Execute ``target_func()`` and report the outcome through queue ``q``.

    Puts ``("success", result)`` on normal completion and ``("error", exc)``
    when the target raises. Being a module-level function keeps it
    pickleable, so it can serve as a multiprocessing.Process target.
    """
    try:
        outcome = target_func()
        q.put(("success", outcome))
    except Exception as exc:
        q.put(("error", exc))
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
# Use a module-level fixture to set the start method on macOS
@pytest.fixture(scope="module", autouse=True)  # Changed scope to module
def set_macos_start_method():
    # On macOS the default start method pickles the process target; the
    # globally decorated tasks in this module reportedly do not pickle
    # cleanly there, so "fork" is forced. No-op on other platforms.
    if sys.platform == "darwin":
        # Force fork method on macOS to avoid pickling issues with globally decorated functions
        # when running tests via pytest discovery.
        current_method = multiprocessing.get_start_method(allow_none=True)
        # Only set if not already set or if set to something else (less likely in test run)
        if current_method is None or current_method != "fork":
            try:
                multiprocessing.set_start_method("fork", force=True)
            except RuntimeError:
                # Might fail if context is already started, ignore in that case.
                pass
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def test_quick_task():  # Renamed from test_multiprocessing_quick_task
    """The globally decorated quick task returns its result untouched."""
    assert quick_task(1) == "quick_ok"
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def test_slow_task_timeout():  # Renamed from test_multiprocessing_slow_task_timeout
    """A task that outlives its budget must surface a TimeoutError."""
    with pytest.raises(TimeoutError) as excinfo:
        slow_task(1)
    # The multiprocessing implementation reports the configured budget in its message.
    assert f"timed out after {TEST_TIMEOUT_SECONDS} seconds" in str(excinfo.value)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def test_internal_exception():  # Renamed from test_multiprocessing_internal_exception
    """Exceptions raised inside the task must propagate through the timeout wrapper."""
    # Decorate the plain task on the fly instead of at definition time.
    wrapped = timeout(seconds=TEST_TIMEOUT_SECONDS)(task_raises_value_error)
    with pytest.raises(ValueError) as excinfo:
        wrapped()
    assert str(excinfo.value) == "Specific value error from task"
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
# --- Test the signal implementation (use_signals=True) ---
|
| 123 |
+
# Note: As per py_functional.py, use_signals=True currently falls back to
|
| 124 |
+
# multiprocessing on POSIX. These tests verify that behavior.
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def test_signal_quick_task_main_process():
    """Signal-based timeout lets a fast task in the main process finish normally."""

    def fast_logic():
        time.sleep(0.1)
        return "quick_ok_signal"

    wrapped = timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=True)(fast_logic)
    assert wrapped() == "quick_ok_signal"
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def test_signal_slow_task_main_process_timeout():
    """Signal-based timeout raises TimeoutError for a slow task in the main process."""

    def slow_logic():
        time.sleep(LONG_TASK_DURATION)
        return "slow_finished_signal"

    wrapped = timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=True)(slow_logic)
    with pytest.raises(TimeoutError) as excinfo:
        wrapped()
    # On POSIX, use_signals falls back to the multiprocessing implementation,
    # so the message matches that path.
    assert f"timed out after {TEST_TIMEOUT_SECONDS} seconds" in str(excinfo.value)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
@pytest.mark.skip(reason="this test won't pass. Just to show why use_signals should not be used")
def test_signal_in_thread_does_not_timeout():
    """
    Tests that signal-based timeout does NOT work reliably in a child thread.
    The TimeoutError from the signal handler is not expected to be raised.
    """
    result_container = []  # Use a list to store result from thread
    exception_container = []  # Use a list to store exception from thread

    @timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=True)
    def slow_task_in_thread():
        try:
            print("Thread: Starting slow task...")
            time.sleep(LONG_TASK_DURATION)
            print("Thread: Slow task finished.")
            return "slow_finished_in_thread"
        except Exception as e:
            # Catch any exception within the thread's target function
            print(f"Thread: Caught exception: {e}")
            exception_container.append(e)
            return None  # Indicate failure

    def thread_target():
        try:
            # Run the decorated function inside the thread
            res = slow_task_in_thread()
            if res is not None:
                result_container.append(res)
        except Exception as e:
            # This might catch exceptions happening *outside* the decorated function
            # but still within the thread target, though less likely here.
            print(f"Thread Target: Caught exception: {e}")
            exception_container.append(e)

    thread = threading.Thread(target=thread_target)
    print("Main: Starting thread...")
    thread.start()
    # Wait longer than the timeout + task duration to ensure the thread finishes
    # regardless of whether timeout worked or not.
    thread.join(timeout=LONG_TASK_DURATION + 1)

    # These assertions describe the *desired* behavior; per the skip reason
    # above, signal-based timeout is not expected to deliver it in a thread.
    assert len(exception_container) == 1
    assert isinstance(exception_container[0], TimeoutError)
    assert not result_container
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def test_in_thread_timeout():
    """The non-signal (use_signals=False) timeout fires correctly even when the
    decorated function runs inside a child thread."""
    result_container = []  # Use a list to store result from thread
    exception_container = []  # Use a list to store exception from thread

    @timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=False)
    def slow_task_in_thread():
        try:
            print("Thread: Starting slow task...")
            time.sleep(LONG_TASK_DURATION)
            print("Thread: Slow task finished.")
            return "slow_finished_in_thread"
        except Exception as e:
            # Catch any exception within the thread's target function
            print(f"Thread: Caught exception: {e}")
            exception_container.append(e)
            return None  # Indicate failure

    def thread_target():
        try:
            # Run the decorated function inside the thread
            res = slow_task_in_thread()
            if res is not None:
                result_container.append(res)
        except Exception as e:
            # This might catch exceptions happening *outside* the decorated function
            # but still within the thread target, though less likely here.
            print(f"Thread Target: Caught exception: {e}")
            exception_container.append(e)

    thread = threading.Thread(target=thread_target)
    print("Main: Starting thread...")
    thread.start()
    # Wait longer than the timeout + task duration to ensure the thread finishes
    # regardless of whether timeout worked or not.
    thread.join(timeout=LONG_TASK_DURATION + 1)

    # Exactly one TimeoutError must have been captured, and no result produced.
    assert len(exception_container) == 1
    assert isinstance(exception_container[0], TimeoutError)
    assert not result_container
|
code/RL_model/verl/verl_train/tests/utils/test_torch_functional.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
|
| 17 |
+
import pytest
|
| 18 |
+
import torch
|
| 19 |
+
import torch.distributed as dist
|
| 20 |
+
import torch.multiprocessing as mp
|
| 21 |
+
|
| 22 |
+
from verl.utils.device import get_device_name, get_nccl_backend, get_torch_device
|
| 23 |
+
from verl.utils.torch_functional import (
|
| 24 |
+
distributed_masked_mean,
|
| 25 |
+
distributed_mean_max_min_std,
|
| 26 |
+
expand_as_nested,
|
| 27 |
+
masked_mean,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _worker_mean(rank: int, world_size: int, rendezvous_file: str):
    """Per-rank worker: validates distributed_mean_max_min_std across the group.

    Spawned via mp.spawn; each rank contributes the scalar rank+1 and then
    checks the globally reduced statistics against a host-side recomputation.
    """
    # 1) set GPU and init NCCL
    get_torch_device().set_device(rank)
    dist.init_process_group(
        backend=get_nccl_backend(),
        init_method=f"file://{rendezvous_file}",
        rank=rank,
        world_size=world_size,
    )
    # each rank holds tensor [rank+1]
    local = torch.tensor([float(rank + 1)], device=f"{get_device_name()}:{rank}")
    mean, gmax, gmin, gstd = distributed_mean_max_min_std(local, True, True, True)

    # Recompute the expected global statistics on the host.
    values = [float(i + 1) for i in range(world_size)]
    exp_mean = sum(values) / len(values)
    exp_max = max(values)
    exp_min = min(values)
    # Sample variance (n-1 denominator) — presumably matching the
    # convention used by distributed_mean_max_min_std; confirm against impl.
    var = sum((x - exp_mean) ** 2 for x in values) / (len(values) - 1)
    exp_std = var**0.5

    # all ranks should see the same result
    assert torch.allclose(mean.cpu(), torch.tensor(exp_mean)), f"mean@{rank}"
    assert torch.allclose(gmax.cpu(), torch.tensor(exp_max)), f"max@{rank}"
    assert torch.allclose(gmin.cpu(), torch.tensor(exp_min)), f"min@{rank}"
    assert torch.allclose(gstd.cpu(), torch.tensor(exp_std)), f"std@{rank}"

    dist.destroy_process_group()
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@pytest.mark.parametrize(
    "value,mask,gt",
    [
        ([1.0, 2.0, 3.0, 4.0], [1, 0, 0, 1], 2.5),
        ([1.0, 2.0, float("nan"), 4.0], [1, 0, 0, 1], 2.5),
        ([1.0, 2.0, float("nan"), 4.0], [1, 0, 1, 0], float("nan")),
    ],
)
def test_masked_mean(value, mask, gt):
    """masked_mean averages the masked-in entries; a NaN under the mask propagates."""
    actual = masked_mean(torch.tensor(value), torch.tensor(mask))
    expected = torch.tensor(gt)
    # NaN never compares equal, so a NaN expectation needs its own check.
    both_nan = torch.isnan(actual) and torch.isnan(expected)
    assert both_nan or torch.allclose(actual, expected)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
@pytest.mark.parametrize("world_size", [2, 4])
def test_distributed_mean_max_min_std(world_size, tmp_path):
    """Spawn one process per rank; each rank validates the reduced statistics."""
    rendezvous_file = str(tmp_path / "rdzv_mean")
    # tmp_path already exists, but be defensive about the parent directory.
    os.makedirs(os.path.dirname(rendezvous_file), exist_ok=True)

    mp.spawn(fn=_worker_mean, args=(world_size, rendezvous_file), nprocs=world_size, join=True)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def _worker_mask(rank: int, world_size: int, rendezvous_file: str):
    """Per-rank worker: validates distributed_masked_mean across the group.

    Rank 0 masks in its first element (1.0); every other rank masks in its
    second element (2*rank + 2.0). All ranks must see the same global mean.
    """
    get_torch_device().set_device(rank)
    dist.init_process_group(
        backend=get_nccl_backend(),
        init_method=f"file://{rendezvous_file}",
        rank=rank,
        world_size=world_size,
    )

    # build per‐rank tensor and mask
    local_tensor = torch.tensor([rank * 2 + 1.0, rank * 2 + 2.0], device=f"{get_device_name()}:{rank}")
    if rank == 0:
        mask = torch.tensor([1, 0], device=f"{get_device_name()}:{rank}", dtype=torch.float32)
    else:
        mask = torch.tensor([0, 1], device=f"{get_device_name()}:{rank}", dtype=torch.float32)

    gmean = distributed_masked_mean(local_tensor, mask)

    # Host-side reference: 1.0 from rank 0 plus 2*i + 2.0 for every other rank.
    valid_values = [1.0] + [2 * i + 2.0 for i in range(1, world_size)]
    expected_mean = sum(valid_values) / len(valid_values)
    assert torch.allclose(gmean.cpu(), torch.tensor(expected_mean)), f"masked_mean@{rank}"

    dist.destroy_process_group()
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
@pytest.mark.parametrize("world_size", [2, 4])
def test_distributed_masked_mean(world_size, tmp_path):
    """Spawn one process per rank; each rank validates the global masked mean."""
    rendezvous_file = str(tmp_path / "rdzv_mask")
    # tmp_path already exists, but be defensive about the parent directory.
    os.makedirs(os.path.dirname(rendezvous_file), exist_ok=True)

    mp.spawn(fn=_worker_mask, args=(world_size, rendezvous_file), nprocs=world_size, join=True)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def test_expand_as_nested():
    """expand_as_nested repeats one value per component across a jagged nested tensor."""
    a = torch.randn(2)
    b = torch.randn(3)
    c = torch.randn(4)
    nested_tensor = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
    tensor = torch.tensor([1, 2, 3])

    output = expand_as_nested(tensor, nested_tensor)

    # Each entry is repeated to its component's length: 1 -> x2, 2 -> x3, 3 -> x4.
    assert output.values().tolist() == [1, 1, 2, 2, 2, 3, 3, 3, 3]
    # The jagged layout (offsets) of the target must be preserved.
    assert torch.all(output.offsets() == nested_tensor.offsets()).item()

    # test exceptions
    # Presumably rejects a non-nested second argument — confirm against impl.
    with pytest.raises(AssertionError):
        expand_as_nested(tensor, tensor)

    # Wrong length: 4 entries vs 3 nested components.
    other_tensor = torch.tensor([1, 2, 3, 4])

    with pytest.raises(AssertionError):
        expand_as_nested(other_tensor, nested_tensor)

    # Wrong rank: 2-D first argument.
    other_tensor = torch.tensor([[1, 2, 3]])

    with pytest.raises(AssertionError):
        expand_as_nested(other_tensor, nested_tensor)

    # Nested tensor with an extra trailing dimension is also rejected.
    with pytest.raises(AssertionError):
        expand_as_nested(tensor, nested_tensor.unsqueeze(-1))
|