shahidul034 commited on
Commit
7b53b83
·
verified ·
1 Parent(s): c3dfccb

Add files using upload-large-folder tool

Browse files
Files changed (33) hide show
  1. code/RL_model/verl/verl_train/tests/single_controller/__init__.py +13 -0
  2. code/RL_model/verl/verl_train/tests/single_controller/test_auto_padding_on_cpu.py +152 -0
  3. code/RL_model/verl/verl_train/tests/single_controller/test_colocated_workers.py +86 -0
  4. code/RL_model/verl/verl_train/tests/single_controller/test_colocated_workers_fused.py +86 -0
  5. code/RL_model/verl/verl_train/tests/single_controller/test_data_transfer.py +109 -0
  6. code/RL_model/verl/verl_train/tests/single_controller/test_decorator_on_cpu.py +200 -0
  7. code/RL_model/verl/verl_train/tests/single_controller/test_device_mesh_register.py +158 -0
  8. code/RL_model/verl/verl_train/tests/single_controller/test_driverfunc_to_worker.py +85 -0
  9. code/RL_model/verl/verl_train/tests/single_controller/test_fused_workers_on_cpu.py +90 -0
  10. code/RL_model/verl/verl_train/tests/single_controller/test_high_level_scheduling_api.py +103 -0
  11. code/RL_model/verl/verl_train/tests/single_controller/test_rvdz.py +51 -0
  12. code/RL_model/verl/verl_train/tests/single_controller/test_worker_group_torch.py +116 -0
  13. code/RL_model/verl/verl_train/tests/special_e2e/README.md +1 -0
  14. code/RL_model/verl/verl_train/tests/utils/test_activation_offload.py +175 -0
  15. code/RL_model/verl/verl_train/tests/utils/test_check_ipc_version_support_on_npu.py +231 -0
  16. code/RL_model/verl/verl_train/tests/utils/test_config_on_cpu.py +97 -0
  17. code/RL_model/verl/verl_train/tests/utils/test_flops_counter.py +480 -0
  18. code/RL_model/verl/verl_train/tests/utils/test_fs_on_cpu.py +94 -0
  19. code/RL_model/verl/verl_train/tests/utils/test_groupwise.py +98 -0
  20. code/RL_model/verl/verl_train/tests/utils/test_import_utils_on_cpu.py +97 -0
  21. code/RL_model/verl/verl_train/tests/utils/test_linear_cross_entropy.py +361 -0
  22. code/RL_model/verl/verl_train/tests/utils/test_mlflow_key_sanitization.py +64 -0
  23. code/RL_model/verl/verl_train/tests/utils/test_model_on_cpu.py +52 -0
  24. code/RL_model/verl/verl_train/tests/utils/test_nvtx_profile.py +168 -0
  25. code/RL_model/verl/verl_train/tests/utils/test_rollout_skip_on_cpu.py +142 -0
  26. code/RL_model/verl/verl_train/tests/utils/test_rollout_trace_on_cpu.py +246 -0
  27. code/RL_model/verl/verl_train/tests/utils/test_seqlen_balancing.py +278 -0
  28. code/RL_model/verl/verl_train/tests/utils/test_shared_memory.py +260 -0
  29. code/RL_model/verl/verl_train/tests/utils/test_special_linear_cross_entropy_tp.py +514 -0
  30. code/RL_model/verl/verl_train/tests/utils/test_special_mstx_profile.py +274 -0
  31. code/RL_model/verl/verl_train/tests/utils/test_temp_env_on_cpu.py +143 -0
  32. code/RL_model/verl/verl_train/tests/utils/test_timeout_decorator_cpu.py +238 -0
  33. code/RL_model/verl/verl_train/tests/utils/test_torch_functional.py +152 -0
code/RL_model/verl/verl_train/tests/single_controller/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
code/RL_model/verl/verl_train/tests/single_controller/test_auto_padding_on_cpu.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import numpy as np
16
+ import ray
17
+ import torch
18
+
19
+ from verl import DataProto
20
+ from verl.protocol import DataProtoConfig
21
+ from verl.single_controller.base import Worker
22
+ from verl.single_controller.base.decorator import Dispatch, register
23
+ from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
24
+
25
# Globally enable DataProto auto-padding for every proto built in this module.
# or set env var VERL_AUTO_PADDING = "1" / "true"
DataProtoConfig.auto_padding = True
27
+
28
+
29
@ray.remote
class Actor(Worker):
    """Minimal test worker: offsets the "a" tensor of each DP shard by its rank."""

    def __init__(self) -> None:
        super().__init__()

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def add(self, data: DataProto):
        # In-place add: the shard processed by rank r comes back shifted by r,
        # which lets the driver verify how data was dispatched.
        data.batch["a"] += self.rank
        return data
38
+
39
+
40
def test_auto_padding():
    """End-to-end check of DataProto auto-padding.

    Verifies that (1) manual ``DataProto.padding`` extends the proto to the
    next multiple of the chunk size, (2) ``chunk`` then yields evenly sized
    shards, and (3) a worker-group method registered with
    ``Dispatch.DP_COMPUTE_PROTO`` transparently pads/unpads inputs whose
    length is not divisible by the worker count -- both via the global
    ``DataProtoConfig.auto_padding`` flag and via the per-proto
    ``auto_padding`` constructor argument.
    """
    ray.init(num_cpus=100)

    chunk_size = 4
    actor_cls = RayClassWithInitArgs(cls=Actor)
    resource_pool = RayResourcePool(process_on_nodes=[chunk_size], use_gpu=False)
    actor_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=actor_cls)

    # test locally first
    for test_size in range(4, 20):
        local_data = DataProto.from_dict({"a": torch.zeros(test_size)}, {"na": np.zeros(test_size, dtype=object)})
        padding_size = (chunk_size - (test_size % chunk_size)) if (test_size % chunk_size > 0) else 0
        local_data.padding(padding_size)
        # BUGFIX: the previous assertion compared len(local_data) against
        # len(local_data) + len(local_data) % chunk_size, which is always true
        # after padding (the remainder is 0), so it could never fail.
        # Compare against the original size plus the computed pad instead.
        assert len(local_data) == test_size + padding_size, (
            f"expecting padded length to be {test_size + padding_size}, but got {len(local_data)}"
        )
        chunked = local_data.chunk(chunk_size)
        assert len(chunked) == chunk_size, f"during test_size = {test_size}, expecting {chunk_size}, got {chunked}"
        expected_dp_len = test_size // chunk_size + bool(test_size % chunk_size)
        for dp in chunked:
            assert len(dp) == expected_dp_len, (
                f"test size = {test_size}, expecting dp to be length of {expected_dp_len}, but got {len(dp)}: {dp} {chunked}"
            )

    def _check(size: int, output: DataProto, how: str) -> None:
        # The dispatched output must be unpadded back to the request size.
        print(output.batch["a"])
        assert len(output) == size, f"Failed in {how} split and padding."

    def _make(size: int, **kwargs) -> DataProto:
        # Fresh proto per call: workers mutate their shards in place.
        return DataProto.from_dict(
            {"a": torch.zeros(size)}, {"na": np.array([str(i) for i in range(size)], dtype=object)}, **kwargs
        )

    def _make_single(size: int, **kwargs) -> DataProto:
        return DataProto.from_single_dict(
            {"a": torch.zeros(size), "na": np.array([str(i) for i in range(size)], dtype=object)}, **kwargs
        )

    # test with RayWorkerGroup method decorated as dispatch_mode=Dispatch.DP_COMPUTE_PROTO,
    # exercising both positional and keyword argument dispatch.
    for size in (10, 1, 8):
        _check(size, actor_wg.add(_make(size)), "args")
        _check(size, actor_wg.add(data=_make(size)), "kwargs")

    # test data proto specific config: the per-proto auto_padding flag must
    # still work once the global flag is disabled.
    DataProtoConfig.auto_padding = False

    _check(10, actor_wg.add(_make(10, auto_padding=True)), "args")
    _check(10, actor_wg.add(data=_make(10, auto_padding=True)), "kwargs")

    _check(1, actor_wg.add(_make_single(1, auto_padding=True)), "args")
    _check(1, actor_wg.add(data=_make_single(1, auto_padding=True)), "kwargs")

    # size 8 divides evenly, so no padding is needed even without the flag.
    _check(8, actor_wg.add(_make_single(8)), "args")
    _check(8, actor_wg.add(data=_make_single(8)), "kwargs")

    ray.shutdown()
149
+
150
+
151
# Allow running this test directly without pytest.
if __name__ == "__main__":
    test_auto_padding()
code/RL_model/verl/verl_train/tests/single_controller/test_colocated_workers.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import ray
16
+
17
+ from verl import DataProto
18
+ from verl.single_controller.base import Worker
19
+ from verl.single_controller.base.decorator import Dispatch, register
20
+ from verl.single_controller.ray.base import (
21
+ RayClassWithInitArgs,
22
+ RayResourcePool,
23
+ RayWorkerGroup,
24
+ create_colocated_worker_cls,
25
+ )
26
+ from verl.utils.device import get_device_name
27
+
28
+
29
@ray.remote
class Actor(Worker):
    """Minimal test worker: offsets the "a" tensor of each DP shard by its rank."""

    def __init__(self) -> None:
        super().__init__()

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def add(self, data: DataProto):
        # In-place add: shards come back shifted by the processing worker's rank.
        data.batch["a"] += self.rank
        return data
38
+
39
+
40
@ray.remote
class Critic(Worker):
    """Test worker that subtracts a configured constant from the "a" tensor.

    NOTE(review): ``sub`` is an ``async def`` here while the fused-colocation
    variant of this test uses a plain ``def`` -- presumably to exercise async
    method dispatch through ``register``; confirm against the decorator docs.
    """

    def __init__(self, config) -> None:
        super().__init__()
        # Expects a mapping with key "b": the amount to subtract.
        self.config = config

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    async def sub(self, data: DataProto):
        data.batch["a"] -= self.config["b"]
        return data
50
+
51
+
52
def test_colocated_workers():
    """Colocated actor/critic worker groups must produce bitwise-identical
    results to the same roles running as standalone worker groups."""
    ray.init()

    import torch

    data = DataProto.from_dict({"a": torch.zeros(10)})

    # First, run the two roles as independent worker groups on a shared pool.
    actor_cls = RayClassWithInitArgs(cls=Actor)
    critic_cls = RayClassWithInitArgs(cls=Critic, config={"b": 10})
    resource_pool = RayResourcePool(process_on_nodes=[2])

    device = get_device_name()
    actor_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=actor_cls, device_name=device)
    critic_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=critic_cls, device_name=device)

    expected_actor_output = actor_wg.add(data)
    expected_critic_output = critic_wg.sub(data)

    # Now fuse both roles into one colocated worker class and spawn
    # per-role facades from the combined group.
    cls_dict = {"actor": actor_cls, "critic": critic_cls}
    wg_dict = RayWorkerGroup(
        resource_pool=resource_pool,
        ray_cls_with_init=create_colocated_worker_cls(cls_dict),
        device_name=device,
    )
    spawn_wg = wg_dict.spawn(prefix_set=cls_dict.keys())

    actor_output = spawn_wg["actor"].add(data)
    critic_output = spawn_wg["critic"].sub(data)

    # Exact equality: colocation must not perturb numerics at all.
    torch.testing.assert_close(expected_actor_output.batch, actor_output.batch, atol=0, rtol=0)
    torch.testing.assert_close(expected_critic_output.batch, critic_output.batch, atol=0, rtol=0)

    ray.shutdown()
code/RL_model/verl/verl_train/tests/single_controller/test_colocated_workers_fused.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import ray
16
+
17
+ from verl import DataProto
18
+ from verl.single_controller.base import Worker
19
+ from verl.single_controller.base.decorator import Dispatch, register
20
+ from verl.single_controller.ray.base import (
21
+ RayClassWithInitArgs,
22
+ RayResourcePool,
23
+ RayWorkerGroup,
24
+ create_colocated_worker_cls_fused,
25
+ )
26
+ from verl.utils.device import get_device_name
27
+
28
+
29
@ray.remote
class Actor(Worker):
    """Minimal test worker: offsets the "a" tensor of each DP shard by its rank."""

    def __init__(self) -> None:
        super().__init__()

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def add(self, data: DataProto):
        # In-place add: shards come back shifted by the processing worker's rank.
        data.batch["a"] += self.rank
        return data
38
+
39
+
40
@ray.remote
class Critic(Worker):
    """Test worker that subtracts a configured constant from the "a" tensor."""

    def __init__(self, config) -> None:
        super().__init__()
        # Expects a mapping with key "b": the amount to subtract.
        self.config = config

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def sub(self, data: DataProto):
        data.batch["a"] -= self.config["b"]
        return data
50
+
51
+
52
def test_colocated_workers_fused():
    """Fused colocated actor/critic worker groups must produce bitwise-identical
    results to the same roles running as standalone worker groups."""
    ray.init()

    import torch

    data = DataProto.from_dict({"a": torch.zeros(10)})

    # First, run the two roles as independent worker groups on a shared pool.
    actor_cls = RayClassWithInitArgs(cls=Actor)
    critic_cls = RayClassWithInitArgs(cls=Critic, config={"b": 10})
    resource_pool = RayResourcePool(process_on_nodes=[2])

    device = get_device_name()
    actor_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=actor_cls, device_name=device)
    critic_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=critic_cls, device_name=device)

    expected_actor_output = actor_wg.add(data)
    expected_critic_output = critic_wg.sub(data)

    # Now fuse both roles into one colocated worker class (fused variant)
    # and spawn per-role facades from the combined group.
    cls_dict = {"actor": actor_cls, "critic": critic_cls}
    wg_dict = RayWorkerGroup(
        resource_pool=resource_pool,
        ray_cls_with_init=create_colocated_worker_cls_fused(cls_dict),
        device_name=device,
    )
    spawn_wg = wg_dict.spawn(prefix_set=cls_dict.keys())

    actor_output = spawn_wg["actor"].add(data)
    critic_output = spawn_wg["critic"].sub(data)

    # Exact equality: fused colocation must not perturb numerics at all.
    torch.testing.assert_close(expected_actor_output.batch, actor_output.batch, atol=0, rtol=0)
    torch.testing.assert_close(expected_critic_output.batch, critic_output.batch, atol=0, rtol=0)

    ray.shutdown()
code/RL_model/verl/verl_train/tests/single_controller/test_data_transfer.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ In this test, we instantiate a data parallel worker with 8 GPUs
16
+ """
17
+
18
+ import ray
19
+ import tensordict
20
+ import torch
21
+ from codetiming import Timer
22
+ from packaging import version
23
+ from torch import distributed as dist
24
+
25
+ from verl import DataProto
26
+ from verl.single_controller.base import Worker
27
+ from verl.single_controller.base.decorator import Dispatch, register
28
+ from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
29
+ from verl.utils.device import get_device_name
30
+ from verl.utils.ray_utils import parallel_put
31
+
32
+
33
@ray.remote
class DummyWorker(Worker):
    """Test worker that increments every tensor in the incoming batch by 1."""

    def __init__(self):
        super().__init__()
        dist.init_process_group()

    @register(dispatch_mode=Dispatch.DP_COMPUTE, blocking=False)
    def do_nothing(self, data):
        # Despite the name, this mutates each tensor so the driver can verify
        # that data made a round trip through the worker.
        for key in data.batch.keys():
            data.batch[key] += 1
        # consolidate() before returning on tensordict >= 0.5 --
        # NOTE(review): presumably required for efficient serialization back
        # to the driver (mirrors the driver-side consolidate); confirm.
        if version.parse(tensordict.__version__) >= version.parse("0.5.0"):
            data.batch = data.batch.consolidate()
        return data
46
+
47
+
48
def test_data_transfer():
    """Time serialization/put/dispatch/get of large DataProto chunks across a
    data-parallel worker group, and verify the round-tripped values."""
    ray.init()

    # Build an 8-process worker group on a single node.
    resource_pool = RayResourcePool([8])
    cls_with_init = RayClassWithInitArgs(cls=DummyWorker)
    wg = RayWorkerGroup(resource_pool, cls_with_init, device_name=get_device_name())

    # Realistic dataset dimensions.
    batch_size = 4096
    seqlen = 32768

    data_dict = {str(idx): torch.randint(0, 10000, (batch_size, seqlen)) for idx in range(2)}
    data = DataProto.from_dict(tensors=data_dict)
    print(data)

    # Split manually: one chunk per worker.
    data_list = data.chunk(wg.world_size)

    # consolidate is necessary on tensordict >= 0.5; the check is loop-invariant.
    if version.parse(tensordict.__version__) >= version.parse("0.5.0"):
        for chunk in data_list:
            chunk.batch = chunk.batch.consolidate()

    with Timer(name="ray.pickle", initial_text=True):
        for chunk in data_list:
            ray.cloudpickle.pickle.dumps(chunk)

    with Timer(name="raw.pickle", initial_text=True):
        import pickle

        for chunk in data_list:
            pickle.dumps(chunk)

    # Put the chunks into the object store up front (takes around 40 seconds).
    with Timer(name="put", initial_text=True):
        data_list_ref = parallel_put(data_list)

    with Timer(name="launch", initial_text=True):
        output_ref = wg.do_nothing(data_list_ref)

    # Fetch the results (takes around 40 seconds).
    with Timer(name="get", initial_text=True):
        output_lst = ray.get(output_ref)

    # Every tensor must come back incremented by exactly 1.
    for input_data, output_data in zip(data_list, output_lst, strict=True):
        for key in input_data.batch.keys():
            assert torch.all(torch.eq(input_data.batch[key] + 1, output_data.batch[key])), (
                input_data.batch[key],
                output_data.batch[key],
                key,
            )

    ray.shutdown()
code/RL_model/verl/verl_train/tests/single_controller/test_decorator_on_cpu.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import asyncio
16
+ import time
17
+
18
+ import pytest
19
+ import ray
20
+ import torch
21
+ from tensordict import TensorDict
22
+
23
+ from verl.protocol import DataProto, DataProtoFuture
24
+ from verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register
25
+ from verl.single_controller.base.worker import Worker
26
+ from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
27
+ from verl.utils import tensordict_utils as tu
28
+
29
+
30
# Pytest fixture for Ray setup/teardown
@pytest.fixture
def ray_init_shutdown():
    # 100 logical CPUs so the small resource pools below are never starved.
    ray.init(num_cpus=100)
    yield
    ray.shutdown()
36
+
37
+
38
# Define a simple worker for testing
@ray.remote
class DecoratorTestWorker(Worker):
    """Worker exercising the @register decorator: sync DP dispatch, async DP
    dispatch, and mesh-based TensorDict dispatch."""

    def __init__(self, initial_value=0):
        super().__init__()
        # Constant added to every input; lets tests distinguish worker config.
        self.value = initial_value
        # Simulate some setup if needed
        time.sleep(0.1)  # Ensure worker init completes

        # Register this rank on the "train" mesh so mesh-based dispatch works.
        self._register_dispatch_collect_info(mesh_name="train", dp_rank=self.rank, is_collect=True)

    # Test method for synchronous DP compute (default behavior)
    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def dp_compute(self, data: DataProto) -> DataProto:
        time.sleep(0.1)  # Simulate work
        rank_value = torch.tensor(self.rank, device=data.batch["input"].device, dtype=data.batch["input"].dtype)
        data.batch["output"] = data.batch["input"] + self.value + rank_value
        return data

    # Test async def method with DP compute (default behavior)
    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO, blocking=False)
    async def async_dp_compute(self, data: DataProto) -> DataProto:
        await asyncio.sleep(0.1)  # Simulate async work
        rank_value = torch.tensor(self.rank, device=data.batch["input"].device, dtype=data.batch["input"].dtype)
        # Distinct formula (input * 2 + ...) so sync/async paths are distinguishable.
        data.batch["output_async"] = data.batch["input"] * 2 + self.value + rank_value
        return data

    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="train"), blocking=False)
    def dp_compute_td(self, data: TensorDict) -> TensorDict:
        # note that we have to call contiguous so that we can modify data in place
        data = tu.contiguous(data)
        rank_value = torch.tensor(self.rank, device=data["input"].device, dtype=data["input"].dtype)
        data["output"] = data["input"] + self.value + rank_value
        position_ids = data.pop("position_ids")
        # NOTE(review): touches a private nested-tensor field to restore the
        # ragged dimension after transfer -- presumably lost in serialization;
        # confirm against torch nested-tensor internals.
        position_ids._ragged_idx = 2

        # Sample i on rank r originally had sequence length 4 + 2r + i;
        # verify the jagged position_ids survived dispatch intact.
        for i, position_id in enumerate(position_ids.unbind(dim=0)):
            assert (position_id == torch.arange(4 + rank_value * 2 + i).expand(position_id.shape)).all()

        return data
79
+
80
+
81
+ # Test function for synchronous DP compute
82
def test_decorator_dp_compute(ray_init_shutdown):
    """
    Tests the default behavior of a synchronous decorated method with DP_COMPUTE_PROTO.
    Verifies the result correctness.
    """
    world_size = 2
    pool = RayResourcePool([world_size], use_gpu=False, max_colocate_count=1)  # CPU-only pool
    worker_cls = RayClassWithInitArgs(cls=DecoratorTestWorker, initial_value=10)
    wg = RayWorkerGroup(pool, worker_cls, name_prefix=f"decorator_test_sync_dp_{int(time.time())}")

    # Four samples, split two per worker.
    data = DataProto(batch=TensorDict({"input": torch.arange(4, dtype=torch.float32)}, batch_size=[4]))

    output = wg.dp_compute(data)

    assert isinstance(output, DataProto), "Expected DataProto result"
    assert "output" in output.batch.keys()
    assert len(output) == len(data), "Output length should match input length"

    # Worker r holds samples [2r, 2r+1] and adds initial_value(10) + rank(r).
    expected_output = torch.cat(
        [torch.tensor(half, dtype=torch.float32) + 10 + rank for rank, half in enumerate(([0, 1], [2, 3]))]
    )

    torch.testing.assert_close(output.batch["output"], expected_output, msg="Sync DP compute output data mismatch")
115
+
116
+
117
+ # Test function for async def method with DP compute
118
def test_decorator_async_function(ray_init_shutdown):
    """
    Tests the decorator with an `async def` method using DP_COMPUTE_PROTO.
    Verifies that the call returns a future and the result is correct after .get().
    """
    world_size = 2
    pool = RayResourcePool([world_size], use_gpu=False, max_colocate_count=1)
    worker_cls = RayClassWithInitArgs(cls=DecoratorTestWorker, initial_value=5)
    wg = RayWorkerGroup(pool, worker_cls, name_prefix=f"decorator_test_async_dp_{int(time.time())}")

    # Four samples, split two per worker.
    data = DataProto(batch=TensorDict({"input": torch.arange(4, dtype=torch.float32)}, batch_size=[4]))

    # A non-blocking async method hands back a future, not a DataProto.
    future_output: DataProtoFuture = wg.async_dp_compute(data)
    assert isinstance(future_output, DataProtoFuture), "Expected DataProtoFuture for async def call"

    # Resolving the future blocks until all workers finish.
    result_data = future_output.get()

    assert isinstance(result_data, DataProto)
    assert "output_async" in result_data.batch.keys()
    assert len(result_data) == len(data), "Output length should match input length"

    # Worker r holds samples [2r, 2r+1] and computes input * 2 + initial_value(5) + rank(r).
    expected_output = torch.cat(
        [torch.tensor(half, dtype=torch.float32) * 2 + 5 + rank for rank, half in enumerate(([0, 1], [2, 3]))]
    )

    torch.testing.assert_close(
        result_data.batch["output_async"], expected_output, msg="Async DP compute output data mismatch"
    )
159
+
160
+
161
def test_decorator_dp_compute_td(ray_init_shutdown):
    """Like test_decorator_dp_compute, but dispatches a raw TensorDict (with a
    jagged nested position_ids tensor) through the registered "train" mesh."""
    world_size = 2
    pool = RayResourcePool([world_size], use_gpu=False, max_colocate_count=1)  # CPU-only pool
    worker_cls = RayClassWithInitArgs(cls=DecoratorTestWorker, initial_value=10)
    wg = RayWorkerGroup(pool, worker_cls, name_prefix=f"decorator_test_sync_dp_{int(time.time())}")

    # Four samples; sample i carries a (4, 4 + i) position_ids block, packed
    # as a jagged nested tensor.
    input_tensor = torch.arange(4, dtype=torch.float32)
    position_ids = torch.nested.as_nested_tensor(
        [torch.arange(seq_len).expand(4, seq_len).contiguous() for seq_len in range(4, 8)],
        layout=torch.jagged,
    )
    data = TensorDict({"input": input_tensor, "position_ids": position_ids}, batch_size=[4])

    # Non-blocking dispatch returns a future; resolve it right away.
    output = wg.dp_compute_td(data).get()

    assert isinstance(output, TensorDict), "Expected DataProto result"
    assert "output" in output.keys()
    assert len(output) == len(data), "Output length should match input length"

    # Worker r holds samples [2r, 2r+1] and adds initial_value(10) + rank(r).
    expected_output = torch.cat(
        [torch.tensor(half, dtype=torch.float32) + 10 + rank for rank, half in enumerate(([0, 1], [2, 3]))]
    )

    torch.testing.assert_close(output["output"], expected_output, msg="Sync DP compute output data mismatch")
code/RL_model/verl/verl_train/tests/single_controller/test_device_mesh_register.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import numpy as np
17
+ import ray
18
+ import torch
19
+ from tensordict import TensorDict
20
+
21
+ import verl.utils.tensordict_utils as tu
22
+ from verl import DataProto
23
+ from verl.single_controller.base import Worker
24
+ from verl.single_controller.base.decorator import make_nd_compute_dataproto_dispatch_fn, register
25
+ from verl.utils.device import get_device_name, get_nccl_backend
26
+
27
+
28
@ray.remote
class TestActor(Worker):
    """Worker registering two dispatch meshes: an inference mesh (dp=2, tp=4)
    and a training mesh (pp=2, dp=2, tp=2)."""

    def __init__(self):
        super().__init__()

        import torch.distributed

        torch.distributed.init_process_group(backend=get_nccl_backend())
        mesh_init = torch.distributed.device_mesh.init_device_mesh
        self.infer_device_mesh = mesh_init(
            device_type=get_device_name(), mesh_shape=[2, 4], mesh_dim_names=["dp", "tp"]
        )
        self.train_device_mesh = mesh_init(
            device_type=get_device_name(), mesh_shape=[2, 2, 2], mesh_dim_names=["pp", "dp", "tp"]
        )

        # Inference results are collected from tp rank 0 of every dp group.
        self._register_dispatch_collect_info(
            "infer",
            dp_rank=self.infer_device_mesh["dp"].get_local_rank(),
            is_collect=self.infer_device_mesh["tp"].get_local_rank() == 0,
        )
        # Training results are collected from tp rank 0 on pp stage 1.
        self._register_dispatch_collect_info(
            "train",
            dp_rank=self.train_device_mesh["dp"].get_local_rank(),
            is_collect=self.train_device_mesh["tp"].get_local_rank() == 0
            and self.train_device_mesh["pp"].get_local_rank() == 1,
        )

    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="infer"))
    def generate_data_proto(self, data: DataProto):
        """Shift this shard's "a" by (tp_rank + 1) * dp_rank."""
        tp = self.infer_device_mesh["tp"].get_local_rank()
        dp = self.infer_device_mesh["dp"].get_local_rank()
        data.batch["a"] += (tp + 1) * dp
        return data

    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="infer"))
    def generate_tensordict(self, data: TensorDict):
        """Same as generate_data_proto, but on a raw TensorDict."""
        tp = self.infer_device_mesh["tp"].get_local_rank()
        dp = self.infer_device_mesh["dp"].get_local_rank()
        data["a"] += (tp + 1) * dp
        return data

    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="train"))
    def train_data_proto(self, data: DataProto):
        """Shift "a" by (tp+1)*(dp+2)*(pp+3) on the training mesh."""
        tp = self.train_device_mesh["tp"].get_local_rank()
        dp = self.train_device_mesh["dp"].get_local_rank()
        pp = self.train_device_mesh["pp"].get_local_rank()
        data.batch["a"] += (tp + 1) * (dp + 2) * (pp + 3)
        # On the collect ranks (tp=0, pp=1): dp 0 adds 8 (3 + 8 = 11),
        # dp 1 adds 12 (4 + 12 = 16).
        return data

    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="train"))
    def train_tensordict(self, data: TensorDict):
        """TensorDict twin of train_data_proto."""
        tp = self.train_device_mesh["tp"].get_local_rank()
        dp = self.train_device_mesh["dp"].get_local_rank()
        pp = self.train_device_mesh["pp"].get_local_rank()
        data["a"] += (tp + 1) * (dp + 2) * (pp + 3)
        # Same expected collect-rank additions as train_data_proto: 11 and 16.
        return data

    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="infer"))
    def generate_nested_tensor(self, data: TensorDict):
        """Shift a jagged input_ids shard by tp_rank + dp_rank."""
        tp = self.infer_device_mesh["tp"].get_local_rank()
        dp = self.infer_device_mesh["dp"].get_local_rank()
        assert data.shape[0] == 8
        data["input_ids"] += tp + dp

        print(data)
        return data
98
+
99
+
100
def test_dist_global_info_wg():
    """End-to-end check of mesh-based dispatch/collect.

    Creates a worker group of size 8, registers an "infer" mesh (tp=4, dp=2)
    and a "train" mesh (tp=2, dp=2, pp=2), then verifies that DataProto,
    TensorDict, and jagged nested-tensor payloads are split, computed, and
    collected correctly.

    Fix: wrap the body in try/finally so ``ray.shutdown()`` also runs when an
    assertion fails, instead of leaking the Ray session into later tests.
    """
    from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup

    ray.init()
    try:
        ray_cls = RayClassWithInitArgs(TestActor)
        resource_pool = RayResourcePool(process_on_nodes=[8])
        wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls, device_name=get_device_name())

        # --- infer mesh (dp=2, tp=4): two dp groups of four tp ranks each ---
        infer_input_data_proto = DataProto.from_single_dict(data={"a": torch.tensor([1, 2])})
        infer_output_data_proto = wg.generate_data_proto(infer_input_data_proto)

        assert wg._dispatch_info["infer"] == [0, 0, 0, 0, 1, 1, 1, 1]

        # dp group 0 adds 0, dp group 1 adds 1 (collected at tp rank 0).
        assert torch.all(torch.eq(infer_output_data_proto.batch["a"], torch.tensor([1, 3])))

        infer_input_tensordict = infer_input_data_proto.to_tensordict()
        infer_output_tensordict = wg.generate_tensordict(infer_input_tensordict)
        assert torch.all(torch.eq(infer_output_tensordict["a"], torch.tensor([1, 3])))

        # --- train mesh (pp=2, dp=2, tp=2) ---
        train_input_data_proto = DataProto.from_single_dict(data={"a": torch.tensor([3, 4])})
        train_output_data_proto = wg.train_data_proto(train_input_data_proto)

        assert wg._dispatch_info["train"] == [0, 0, 1, 1, 0, 0, 1, 1]

        # Collect ranks (tp=0, pp=1): 3 + 8 = 11 and 4 + 12 = 16.
        assert torch.all(torch.eq(train_output_data_proto.batch["a"], torch.tensor([11, 16])))

        train_input_tensordict = train_input_data_proto.to_tensordict()
        train_output_tensordict = wg.train_tensordict(train_input_tensordict)
        assert torch.all(torch.eq(train_output_tensordict["a"], torch.tensor([11, 16])))

        # --- jagged nested tensors through the infer mesh ---
        input_ids = [
            torch.randint(low=0, high=128, size=(np.random.randint(low=1, high=10, dtype=np.int64),))
            for _ in range(16)
        ]
        input_ids = torch.nested.as_nested_tensor(input_ids, layout=torch.jagged)
        data = tu.get_tensordict(tensor_dict={"input_ids": input_ids})
        output = wg.generate_nested_tensor(data)

        # Each dp half is shifted by tp_rank + dp_rank; collected at tp rank 0,
        # so half 0 gains 0 and half 1 gains 1.
        input_ids_chunked = list(input_ids.chunk(2))
        input_ids_chunked[0] += 0
        input_ids_chunked[1] += 1

        expected = tu.concat_nested_tensors(input_ids_chunked)

        assert torch.all(torch.eq(output["input_ids"].values(), expected.values()))
    finally:
        # Always tear down Ray, even when an assertion above fails.
        ray.shutdown()


if __name__ == "__main__":
    test_dist_global_info_wg()
code/RL_model/verl/verl_train/tests/single_controller/test_driverfunc_to_worker.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+
17
+ import ray
18
+ import torch
19
+ from tensordict import TensorDict
20
+
21
+ from verl import DataProto
22
+ from verl.single_controller.base.worker import Worker
23
+ from verl.single_controller.ray import RayWorkerGroup
24
+ from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool
25
+ from verl.utils.device import get_device_name
26
+
27
+ os.environ["RAY_DEDUP_LOGS"] = "0"
28
+ os.environ["NCCL_DEBUG"] = "WARN"
29
+
30
+
31
@ray.remote
class ModelActor(Worker):
    """Minimal Worker used purely as a remote execution target."""

    def __init__(self):
        # NOTE(review): deliberately does not call super().__init__() —
        # presumably the test only needs a bare actor; confirm before reuse.
        pass
35
+
36
+
37
class HackSelf:
    """Empty stand-in object passed as ``self`` when executing the driver
    function locally for comparison against the worker-group result."""

    def __init__(self):
        pass
40
+
41
+
42
def get_aux_metrics(self, test_proto):
    """Count the elements in each row of ``sequence_ids`` and return both the
    ids and the per-row counts as a DataProto."""
    sequence_ids = test_proto.batch["sequence_ids"]
    # One entry per row: the length of that row's token list.
    counts = [len(row.tolist()) for row in sequence_ids]
    batch = TensorDict(
        {"sequence_ids": sequence_ids, "decode_count": torch.tensor(counts)},
        batch_size=sequence_ids.size(0),
    )
    return DataProto(batch=batch)
53
+
54
+
55
def test():
    """Run a driver-defined function on remote workers and compare against
    executing the same function locally on the driver.

    Fix: move ``ray.shutdown()`` into a finally block so the Ray session is
    released even when the comparison fails.
    """
    ray.init()
    try:
        # create 2 workers, each hold a GPU
        resource_pool = RayResourcePool([2], use_gpu=True, name_prefix="a")

        class_with_args = RayClassWithInitArgs(cls=ModelActor)
        shard_wg = RayWorkerGroup(resource_pool, class_with_args, device_name=get_device_name())

        test_bs = 8
        test_proto = DataProto(
            TensorDict(
                {
                    "sequence_ids": torch.ones([test_bs, 2048], dtype=torch.int64),
                },
                batch_size=test_bs,
            ),
            meta_info={"query_length": 1536},
        )

        # Sharding among different ranks
        ret_proto1 = shard_wg.execute_with_func_generator(get_aux_metrics, test_proto)

        # compare against executing on the driver
        hs = HackSelf()
        ret_proto2 = get_aux_metrics(hs, test_proto)

        torch.testing.assert_close(ret_proto1.batch["decode_count"], ret_proto2.batch["decode_count"])
    finally:
        # Always release Ray, even if the assertion above fails.
        ray.shutdown()
code/RL_model/verl/verl_train/tests/single_controller/test_fused_workers_on_cpu.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import ray
16
+
17
+ from verl.single_controller.base import Worker
18
+ from verl.single_controller.base.decorator import Dispatch, register
19
+ from verl.single_controller.ray.base import (
20
+ RayClassWithInitArgs,
21
+ RayResourcePool,
22
+ RayWorkerGroup,
23
+ create_colocated_worker_raw_cls,
24
+ )
25
+
26
+
27
@ray.remote
class Actor(Worker):
    """Worker whose ``add`` shifts its input by the worker's rank."""

    def __init__(self) -> None:
        super().__init__()

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def add(self, x):
        # x is a number, so this is equivalent to the in-place form.
        return x + self.rank
36
+
37
+
38
@ray.remote
class Critic(Worker):
    """Worker whose ``sub`` subtracts a fixed offset supplied at init."""

    def __init__(self, val) -> None:
        super().__init__()
        self.val = val

    @register(dispatch_mode=Dispatch.ALL_TO_ALL)
    def sub(self, x):
        return x - self.val
48
+
49
+
50
# Build the raw (non-actor) fused base class that colocates an Actor and a
# Critic inside one process; HybridWorker below derives from it and the test
# accesses the sub-workers through the "actor"/"critic" keys of cls_dict.
actor_cls = RayClassWithInitArgs(cls=Actor)
critic_cls = RayClassWithInitArgs(cls=Critic, val=10)
cls_dict = {"actor": actor_cls, "critic": critic_cls}
FusedBaseClass = create_colocated_worker_raw_cls(cls_dict)
54
+
55
+
56
@ray.remote
class HybridWorker(FusedBaseClass):
    """Fused worker exposing a combined add-then-sub operation."""

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def foo(self, x):
        # Chain the two colocated sub-workers: actor first, then critic.
        added = self.actor.add(x)
        return self.critic.sub(added)
61
+
62
+
63
def test_fused_workers():
    """Check that calls routed through the fused worker's sub-workers match a
    method defined directly on the fused class.

    Fix: wrap the body in try/finally so ``ray.shutdown()`` runs even when an
    assertion fails, and drop the debug prints.
    """
    ray.init(num_cpus=100)
    try:
        # create separate workers on the same resource pool (one node, 2 CPUs)
        process_on_nodes = [2]
        resource_pool = RayResourcePool(process_on_nodes=process_on_nodes, use_gpu=False)

        # create colocated workers
        hybrid_cls_with_init = RayClassWithInitArgs(cls=HybridWorker)
        hybrid_cls_with_init.fused_worker_used = True

        fused_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=hybrid_cls_with_init)
        fused_wg.fuse(cls_dict.keys())

        # Two-step path: actor.add then critic.sub ...
        x = fused_wg.actor.add(0.1)
        y = fused_wg.critic.sub(x)
        # ... must agree element-wise with the fused single call.
        z = fused_wg.foo(0.1)
        for i, j in zip(y, z, strict=True):
            assert i == j
    finally:
        # Always tear down Ray so later tests start from a clean state.
        ray.shutdown()


if __name__ == "__main__":
    test_fused_workers()
code/RL_model/verl/verl_train/tests/single_controller/test_high_level_scheduling_api.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import gc
15
+ import time
16
+
17
+ import ray
18
+
19
+ from verl.single_controller.base.worker import Worker
20
+ from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, merge_resource_pool
21
+ from verl.utils.device import get_device_name
22
+
23
+
24
@ray.remote
class TestActor(Worker):
    """Worker that reports its host node id for placement checks."""

    # TODO: pass *args and **kwargs is bug prone and not very convincing
    def __init__(self, cuda_visible_devices=None) -> None:
        super().__init__(cuda_visible_devices)

    def get_node_id(self):
        """Return the Ray node id this actor is scheduled on."""
        return ray.get_runtime_context().get_node_id()
32
+
33
+
34
def test():
    """Exercise resource-pool scheduling.

    Phase 1: four worker groups share one 8-GPU pool. Phase 2: two groups on
    disjoint 4-GPU pools that are merged into an 8-GPU pool for two more
    groups.

    Fixes: ``ray.shutdown()`` now runs in a finally block even when an
    assertion fails; the placement-group cleanup uses a plain loop instead of
    a side-effect list comprehension; a typo in a progress message is fixed.
    """
    ray.init()
    try:
        # test single-node-no-partition
        print("test single-node-no-partition")
        resource_pool = RayResourcePool([8], use_gpu=True)

        class_with_args = RayClassWithInitArgs(cls=TestActor)

        print("create actor worker group")
        actor_wg = RayWorkerGroup(
            resource_pool, class_with_args, name_prefix="high_level_api_actor", device_name=get_device_name()
        )
        print("create critic worker group")
        # NOTE(review): "hight_level_api_critic" looks like a typo, but it only
        # needs to be a unique prefix, so it is preserved as-is.
        critic_wg = RayWorkerGroup(
            resource_pool, class_with_args, name_prefix="hight_level_api_critic", device_name=get_device_name()
        )
        print("create rm worker group")
        rm_wg = RayWorkerGroup(
            resource_pool, class_with_args, name_prefix="high_level_api_rm", device_name=get_device_name()
        )
        print("create ref worker group")
        ref_wg = RayWorkerGroup(
            resource_pool, class_with_args, name_prefix="high_level_api_ref", device_name=get_device_name()
        )

        # All four groups share the same 8 GPUs.
        assert actor_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]
        assert critic_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]
        assert rm_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]
        assert ref_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]

        del actor_wg
        del critic_wg
        del rm_wg
        del ref_wg
        gc.collect()  # make sure ray actors are deleted

        for pg in resource_pool.get_placement_groups():
            ray.util.remove_placement_group(pg)
        print("wait 5s to remove placement_group")
        time.sleep(5)

        # test single-node-multi-partition
        print("test single-node-multi-partition")
        rm_resource_pool = RayResourcePool([4], use_gpu=True, name_prefix="rm")
        ref_resource_pool = RayResourcePool([4], use_gpu=True, name_prefix="ref")
        total_resource_pool = merge_resource_pool(rm_resource_pool, ref_resource_pool)

        assert rm_resource_pool.world_size == 4
        assert ref_resource_pool.world_size == 4
        assert total_resource_pool.world_size == 8

        actor_wg = RayWorkerGroup(
            total_resource_pool, class_with_args, name_prefix="high_level_api_actor", device_name=get_device_name()
        )
        critic_wg = RayWorkerGroup(
            total_resource_pool, class_with_args, name_prefix="high_level_api_critic", device_name=get_device_name()
        )
        rm_wg = RayWorkerGroup(
            rm_resource_pool, class_with_args, name_prefix="high_level_api_rm", device_name=get_device_name()
        )
        ref_wg = RayWorkerGroup(
            ref_resource_pool, class_with_args, name_prefix="high_level_api_ref", device_name=get_device_name()
        )

        # Merged pool spans all 8 GPUs; the partitions get their own halves.
        assert actor_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]
        assert critic_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]
        assert rm_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(4)]
        assert ref_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(4, 8)]
    finally:
        # Always tear down Ray, even when an assertion above fails.
        ray.shutdown()
code/RL_model/verl/verl_train/tests/single_controller/test_rvdz.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import ray
16
+
17
+
18
@ray.remote
class TestWorker:
    """Plain Ray actor that joins an NCCL group via the rendezvous backend."""

    def __init__(self, rank, world_size, group_name):
        self.rank = rank
        self.world_size = world_size
        self.group_name = group_name
        self.communicator = None

    def init(self):
        """Create the NCCL communicator through the Ray rendezvous helper."""
        from verl.utils.rendezvous.ray_backend import create_nccl_communicator_in_ray

        self.communicator = create_nccl_communicator_in_ray(self.rank, self.world_size, self.group_name)

    def test(self):
        """Return this worker's communicator rank, or None if uninitialized."""
        return None if self.communicator is None else self.communicator.rank_id()
35
+
36
+
37
def test_rvdz():
    """Two single-GPU actors rendezvous into one NCCL group and report ranks.

    Fix: ``ray.shutdown()`` now runs in a finally block so the GPU actors are
    released even when the assertion fails.
    """
    ray.init()
    try:
        group_name = "test_group"
        world_size = 2

        workers = [TestWorker.options(num_gpus=1).remote(rank, world_size, group_name) for rank in range(world_size)]

        ray.get([worker.init.remote() for worker in workers])

        ranks = ray.get([worker.test.remote() for worker in workers])

        assert ranks == [0, 1], f"expecting [0, 1], got {ranks}"
    finally:
        ray.shutdown()
code/RL_model/verl/verl_train/tests/single_controller/test_worker_group_torch.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+
17
+ os.environ["RAY_DEDUP_LOGS"] = "0"
18
+ os.environ["NCCL_DEBUG"] = "WARN"
19
+
20
+ import ray
21
+ import torch
22
+ import torch.distributed
23
+
24
+ from verl.single_controller.base.worker import Worker
25
+ from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
26
+ from verl.utils.device import get_device_name
27
+
28
+
29
@ray.remote
class TestAllGatherActor(Worker):
    """Worker with a two-phase setup: construct first, then call ``init`` to
    join the process group and build a rank-valued tensor."""

    def __init__(self, size) -> None:
        super().__init__()
        self.size = size

    def init(self):
        torch.distributed.init_process_group()
        # Each rank holds a vector filled with its own rank id.
        self.tensor = torch.zeros(size=(self.size,), dtype=torch.int64, device=get_device_name())
        self.tensor += self.rank

    def all_gather(self):
        n = self._world_size
        gathered = torch.zeros(
            size=(self.tensor.shape[0] * n,), dtype=self.tensor.dtype, device=self.tensor.device
        )
        torch.distributed.all_gather_into_tensor(gathered, self.tensor, async_op=False)
        return gathered
47
+
48
+
49
@ray.remote
class TestAllGatherActorV2(Worker):
    """Same as TestAllGatherActor, but joins the process group directly in
    ``__init__`` instead of a separate ``init`` call."""

    def __init__(self, size) -> None:
        super().__init__()
        self.size = size

        torch.distributed.init_process_group()
        # Each rank holds a vector filled with its own rank id.
        self.tensor = torch.zeros(size=(self.size,), dtype=torch.int64, device=get_device_name())
        self.tensor += self.rank

    def all_gather(self):
        n = self._world_size
        gathered = torch.zeros(
            size=(self.tensor.shape[0] * n,), dtype=self.tensor.dtype, device=self.tensor.device
        )
        torch.distributed.all_gather_into_tensor(gathered, self.tensor, async_op=False)
        return gathered
66
+
67
+
68
def test_all_gather_torch():
    """
    In this test, we instantiate 4 GPUs in a group and test the all_gather.

    Fix: ``ray.shutdown()`` moved into a finally block so a failed assertion
    does not leak the Ray session; debug print removed.
    """
    ray.init()
    try:
        # create 4 workers, each hold a GPU
        resource_pool = RayResourcePool([4], use_gpu=True)
        class_with_args = RayClassWithInitArgs(cls=TestAllGatherActor, size=2)

        worker_group = RayWorkerGroup(
            resource_pool, class_with_args, name_prefix="worker_group_torch", device_name=get_device_name()
        )

        worker_group.execute_all_sync("init")
        output = worker_group.execute_all_sync("all_gather")
        # Every rank must observe an identical gathered tensor.
        for i in range(1, len(output)):
            assert torch.all(output[i] == output[0])

        output = output[0].cpu()
        assert torch.all(output == torch.tensor([0, 0, 1, 1, 2, 2, 3, 3], dtype=torch.int64))
    finally:
        ray.shutdown()
92
+
93
+
94
def test_all_gather_torch_v2():
    """
    In this test, we instantiate 4 GPUs in a group and test the all_gather
    (process group joined inside the actor's __init__).

    Fix: ``ray.shutdown()`` moved into a finally block so a failed assertion
    does not leak the Ray session; debug print removed.
    """
    ray.init()
    try:
        # create 4 workers, each hold a GPU
        resource_pool = RayResourcePool([4], use_gpu=True)
        class_with_args = RayClassWithInitArgs(cls=TestAllGatherActorV2, size=2)

        worker_group = RayWorkerGroup(
            resource_pool, class_with_args, name_prefix="worker_group_torch", device_name=get_device_name()
        )

        output = worker_group.execute_all_sync("all_gather")
        # Every rank must observe an identical gathered tensor.
        for i in range(1, len(output)):
            assert torch.all(output[i] == output[0])

        output = output[0].cpu()
        assert torch.all(output == torch.tensor([0, 0, 1, 1, 2, 2, 3, 3], dtype=torch.int64))
    finally:
        ray.shutdown()
code/RL_model/verl/verl_train/tests/special_e2e/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ This folder is reserved for end-to-end tests that typically require multiple GPUs.
code/RL_model/verl/verl_train/tests/utils/test_activation_offload.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import shutil
16
+ import tempfile
17
+
18
+ import pytest
19
+ import torch
20
+ import torch.distributed
21
+ import torch.multiprocessing as mp
22
+ from torch.distributed import init_device_mesh
23
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
24
+ from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
25
+ from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2Config
26
+
27
+ from verl.utils.activation_offload import enable_activation_offloading
28
+ from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
29
+ from verl.utils.device import get_device_name, get_nccl_backend, get_torch_device
30
+ from verl.utils.fsdp_utils import MixedPrecisionPolicy, apply_fsdp2, get_fsdp_wrap_policy
31
+
32
+
33
def create_random_input_ids(batch_size, seq_len, vocab_size):
    """Build packed (unpadded) input_ids and position_ids from a random mask."""
    # unpad_input comes from flash-attn on CUDA and verl's fallback on NPU.
    if get_device_name() == "cuda":
        from flash_attn.bert_padding import unpad_input
    elif get_device_name() == "npu":
        from verl.utils.attention_utils import unpad_input
    from verl.utils.model import compute_position_id_with_mask, create_random_mask

    token_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=get_device_name())

    mask = create_random_mask(
        token_ids, max_ratio_of_left_padding=0.1, min_ratio_of_valid_token=0.5, max_ratio_of_valid_token=0.7
    )
    pos_ids = compute_position_id_with_mask(mask)

    # Drop padded positions and flatten into a single packed sequence.
    packed_ids = unpad_input(token_ids.unsqueeze(-1), mask)[0].transpose(0, 1)
    packed_pos = unpad_input(pos_ids.unsqueeze(-1), mask)[0].transpose(0, 1)
    return packed_ids, packed_pos
50
+
51
+
52
def _fsdp_activation_offloading_test(rank, world_size, rendezvous_file, strategy="fsdp"):
    """Per-rank body: verify that enabling activation offloading leaves a
    training step bit-identical.

    Sequence: run one optimizer step and checkpoint; run a second step and
    record logits; enable activation offloading, restore the checkpoint,
    repeat the second step, and require the logits to match exactly.

    Fix: corrected the misspelled success message ("Activaiton").
    """
    get_torch_device().set_device(rank)
    torch.distributed.init_process_group(
        backend=get_nccl_backend(),
        init_method=f"file://{rendezvous_file}",
        rank=rank,
        world_size=world_size,
    )
    device_mesh = init_device_mesh(get_device_name(), mesh_shape=(world_size,), mesh_dim_names=("dp",))

    model_name = os.path.expanduser("~/models/Qwen/Qwen2.5-0.5B-Instruct")
    config = Qwen2Config(num_hidden_layers=4)  # tiny model keeps the test fast

    with torch.device(get_device_name()):
        model = AutoModelForCausalLM.from_config(
            config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
        )
        model = model.to(device=get_device_name())

    # Wrap model with FSDP (v1) or FSDP2 depending on the strategy.
    mixed_precision = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32)

    if strategy == "fsdp":
        model = FSDP(
            model,
            use_orig_params=False,
            device_id=get_torch_device().current_device(),
            sharding_strategy=ShardingStrategy.FULL_SHARD,
            mixed_precision=mixed_precision,
            device_mesh=device_mesh,
            auto_wrap_policy=get_fsdp_wrap_policy(module=model),
        )
    else:
        mp_policy = MixedPrecisionPolicy(
            param_dtype=torch.bfloat16, reduce_dtype=torch.float32, cast_forward_inputs=True
        )
        fsdp_kwargs = {
            "mesh": device_mesh,
            "mp_policy": mp_policy,
        }
        apply_fsdp2(model, fsdp_kwargs, {})

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

    # Create checkpoint manager
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    checkpoint_manager = FSDPCheckpointManager(
        model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, tokenizer=tokenizer
    )

    # Generate sample input
    batch_size = 2
    seq_len = 32
    vocab_size = 32000
    # First input for initial update
    input_ids1, position_ids1 = create_random_input_ids(batch_size, seq_len, vocab_size)

    # Second input for verification
    input_ids2, position_ids2 = create_random_input_ids(batch_size, seq_len, vocab_size)

    # Step 1: Initial update and save checkpoint
    outputs1 = model(input_ids=input_ids1, position_ids=position_ids1)
    loss1 = outputs1.logits.mean()
    loss1.backward()
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()

    # Save checkpoint after first update
    temp_dir = tempfile.mkdtemp()
    checkpoint_path = os.path.join(temp_dir, "checkpoint")
    checkpoint_manager.save_checkpoint(local_path=checkpoint_path, hdfs_path=None, global_step=0)

    # Step 2: Second update and forward pass
    outputs2 = model(input_ids=input_ids2, position_ids=position_ids2)
    loss2 = outputs2.logits.mean()
    loss2.backward()
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()

    # Record logits after second update (reference, no offloading)
    with torch.no_grad():
        logits_without_offloading = model(input_ids=input_ids2, position_ids=position_ids2).logits

    # Step 3: wrap module with activation offloading and load checkpoint
    enable_activation_offloading(model, strategy=strategy)
    checkpoint_manager.load_checkpoint(checkpoint_path)

    # Step 4: Repeat the second update with same input
    outputs3 = model(input_ids=input_ids2, position_ids=position_ids2)
    loss3 = outputs3.logits.mean()
    loss3.backward()
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()

    # Record logits after loaded checkpoint and update
    with torch.no_grad():
        logits_with_offloading = model(input_ids=input_ids2, position_ids=position_ids2).logits

    # Step 5: Verify outputs match exactly (bitwise, hence atol=rtol=0)
    torch.testing.assert_close(logits_without_offloading, logits_with_offloading, atol=0.0, rtol=0.0)
    print(f"Activation offloading for {strategy} test passed on {world_size} GPUs!")

    # Cleanup
    shutil.rmtree(temp_dir)
    torch.distributed.barrier()
    torch.distributed.destroy_process_group()
162
+
163
+
164
+ @pytest.mark.parametrize("world_size", (2, 4))
165
+ @pytest.mark.parametrize("strategy", ("fsdp", "fsdp2"))
166
+ def test_activation_offloading(world_size, strategy, tmp_path):
167
+ rendezvous_file = str(tmp_path / "rdzv_file")
168
+ os.makedirs(os.path.dirname(rendezvous_file), exist_ok=True)
169
+
170
+ mp.spawn(
171
+ fn=_fsdp_activation_offloading_test,
172
+ args=(world_size, rendezvous_file, strategy),
173
+ nprocs=world_size,
174
+ join=True,
175
+ )
code/RL_model/verl/verl_train/tests/utils/test_check_ipc_version_support_on_npu.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # This code is licensed under the MIT-style license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ import unittest
8
+ from unittest.mock import Mock, mock_open, patch
9
+
10
+ from verl.utils.device import check_ipc_version_support, get_npu_versions
11
+
12
+
13
class TestCheckIPCVersionSupport(unittest.TestCase):
    """Unit tests for check_ipc_version_support.

    IPC is reported as supported when the NPU driver/software version is
    >= 25.3.rc1 AND the CANN toolkit version is >= 8.3.rc1.
    """

    def setUp(self):
        """Silence INFO-level messages emitted while checking versions."""
        logging.disable(logging.INFO)

    def tearDown(self):
        """Re-enable all logging levels."""
        logging.disable(logging.NOTSET)

    def test_standard_version_with_support(self):
        """Standard versions above both minimums are supported."""
        self.assertTrue(check_ipc_version_support("25.5.0", "8.3.0"))

    def test_standard_version_newer(self):
        """Strictly newer standard versions are supported."""
        self.assertTrue(check_ipc_version_support("26.0.0", "9.0.0"))

    def test_rc_version_format(self):
        """RC versions carrying extra trailing parts reduce to their first three parts."""
        self.assertTrue(check_ipc_version_support("25.3.rc1.2", "8.3.rc1.2"))

    def test_exact_rc_version(self):
        """Versions exactly at the minimums (25.3.rc1 / 8.3.rc1) are supported."""
        self.assertTrue(check_ipc_version_support("25.3.rc1", "8.3.rc1"))

    def test_t_suffix_version(self):
        """A lowercase `t` suffix in the software version is tolerated."""
        # 25.5.t3.b001 is treated as 25.5, which clears 25.3.rc1.
        self.assertTrue(check_ipc_version_support("25.5.t3.b001", "8.3.rc1"))

    def test_t_suffix_version_older(self):
        """A passing software version cannot compensate for an old CANN."""
        # Software side passes (25.5 >= 25.3.rc1) but CANN 8.2.rc1 < 8.3.rc1.
        self.assertFalse(check_ipc_version_support("25.5.t3.b001", "8.2.rc1"))

    def test_software_version_below_minimum(self):
        """Software below 25.3.rc1 is rejected even with a new CANN."""
        self.assertFalse(check_ipc_version_support("25.2.0", "8.3.0"))

    def test_cann_version_below_minimum(self):
        """CANN below 8.3.rc1 is rejected even with new software."""
        self.assertFalse(check_ipc_version_support("25.5.0", "8.2.0"))

    def test_both_versions_below_minimum(self):
        """Both components below minimum are rejected."""
        self.assertFalse(check_ipc_version_support("25.2.0", "8.2.0"))

    def test_invalid_software_version(self):
        """An unparseable software version raises RuntimeError."""
        with self.assertRaises(RuntimeError) as context:
            check_ipc_version_support("invalid.version", "8.3.0")
        self.assertIn("Invalid software version format", str(context.exception))

    def test_invalid_cann_version(self):
        """An unparseable CANN version raises RuntimeError."""
        with self.assertRaises(RuntimeError) as context:
            check_ipc_version_support("25.5.0", "invalid.version")
        self.assertIn("Invalid CANN version format", str(context.exception))

    def test_rc_with_more_parts(self):
        """Only the first three parts of a long RC version are compared."""
        self.assertTrue(check_ipc_version_support("25.3.rc1.2.3.4", "8.3.rc1.2.3.4"))

    def test_standard_with_more_parts(self):
        """Only the first three parts of a long standard version are compared."""
        self.assertTrue(check_ipc_version_support("25.5.0.1.2.3", "8.3.0.1.2.3"))

    def test_rc_edge_case_versions(self):
        """rc1 is the minimum accepted RC level; rc0 falls just below it."""
        self.assertTrue(check_ipc_version_support("25.3.rc1", "8.3.rc1"))
        self.assertFalse(check_ipc_version_support("25.3.rc0", "8.3.rc1"))

    def test_major_version_differences(self):
        """Major-version jumps dominate the comparison in both directions."""
        self.assertTrue(check_ipc_version_support("30.0.0", "10.0.0"))
        self.assertFalse(check_ipc_version_support("24.0.0", "7.0.0"))
128
+
129
+
130
class TestGetNPUVersions(unittest.TestCase):
    """Test cases for the get_npu_versions function."""

    # NOTE: @patch decorators are applied bottom-up, so mock parameters arrive
    # in reverse decorator order: (mock_file, mock_exists, mock_machine, mock_run).

    @patch("subprocess.run")
    @patch("platform.machine")
    @patch("os.path.exists")
    @patch("builtins.open", new_callable=mock_open, read_data="version=8.3.rc1\n")
    def test_get_npu_versions_success(self, mock_file, mock_exists, mock_machine, mock_run):
        """Test successful retrieval of versions."""
        # Mock npu-smi output
        mock_run.return_value = Mock(stdout="Software Version : 25.5.0\nOther Info\n", check=True)

        # Mock architecture
        mock_machine.return_value = "x86_64"

        # Mock path exists
        mock_exists.return_value = True

        software_version, cann_version = get_npu_versions()

        # Software version comes from the npu-smi stdout; CANN version from the
        # mocked info file contents ("version=8.3.rc1").
        self.assertEqual(software_version, "25.5.0")
        self.assertEqual(cann_version, "8.3.rc1")

    @patch("subprocess.run")
    def test_get_npu_versions_missing_software_version(self, mock_run):
        """Test error when Software Version is missing."""
        # npu-smi output lacks the "Software Version" line entirely.
        mock_run.return_value = Mock(stdout="Other Info Without Software Version\n", check=True)

        with self.assertRaises(RuntimeError) as context:
            get_npu_versions()

        self.assertIn("Could not find Software Version", str(context.exception))

    @patch("subprocess.run")
    @patch("platform.machine")
    @patch("os.path.exists")
    @patch("builtins.open", new_callable=mock_open, read_data="version=8.3.rc1\n")
    def test_get_npu_versions_unsupported_architecture(self, mock_file, mock_exists, mock_machine, mock_run):
        """Test error with unsupported architecture."""
        mock_run.return_value = Mock(stdout="Software Version : 25.5.0\n", check=True)

        mock_machine.return_value = "armv7l"  # Unsupported architecture
        mock_exists.return_value = True

        with self.assertRaises(RuntimeError) as context:
            get_npu_versions()

        self.assertIn("Unsupported architecture", str(context.exception))

    @patch("subprocess.run")
    @patch("platform.machine")
    @patch("os.path.exists")
    @patch("builtins.open", new_callable=mock_open, read_data="version=8.3.rc1\n")
    def test_get_npu_versions_cann_path_not_exists(self, mock_file, mock_exists, mock_machine, mock_run):
        """Test error when CANN path doesn't exist."""
        mock_run.return_value = Mock(stdout="Software Version : 25.5.0\n", check=True)

        mock_machine.return_value = "x86_64"
        mock_exists.return_value = False  # Path doesn't exist

        with self.assertRaises(RuntimeError) as context:
            get_npu_versions()

        self.assertIn("CANN toolkit path does not exist", str(context.exception))

    @patch("subprocess.run")
    @patch("platform.machine")
    @patch("os.path.exists")
    @patch("builtins.open")
    def test_get_npu_versions_info_file_not_exists(self, mock_file, mock_exists, mock_machine, mock_run):
        """Test error when CANN info file doesn't exist."""
        mock_run.return_value = Mock(stdout="Software Version : 25.5.0\n", check=True)

        mock_machine.return_value = "x86_64"

        # First call is for CANN path exists, second call is for info file exists
        # (side_effect ordering is load-bearing: get_npu_versions must call
        # os.path.exists exactly in that order for this test to be meaningful).
        mock_exists.side_effect = [True, False]

        with self.assertRaises(RuntimeError) as context:
            get_npu_versions()

        self.assertIn("CANN toolkit info file does not exist", str(context.exception))

    @patch("subprocess.run")
    @patch("platform.machine")
    @patch("os.path.exists")
    @patch("builtins.open", new_callable=mock_open, read_data="other_info=no_version\n")
    def test_get_npu_versions_missing_cann_version(self, mock_file, mock_exists, mock_machine, mock_run):
        """Test error when CANN version is missing from info file."""
        mock_run.return_value = Mock(stdout="Software Version : 25.5.0\n", check=True)

        mock_machine.return_value = "x86_64"
        mock_exists.return_value = True

        with self.assertRaises(RuntimeError) as context:
            get_npu_versions()

        self.assertIn("Could not find version in CANN toolkit info file", str(context.exception))
228
+
229
+
230
# Allow running this test module directly, outside a pytest invocation.
if __name__ == "__main__":
    unittest.main()
code/RL_model/verl/verl_train/tests/utils/test_config_on_cpu.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import unittest
16
+ from dataclasses import dataclass, field
17
+
18
+ from omegaconf import OmegaConf
19
+
20
+ from verl.base_config import BaseConfig
21
+ from verl.utils import omega_conf_to_dataclass
22
+
23
+
24
@dataclass
class TestDataclass(BaseConfig):
    """Minimal leaf config used to exercise omega_conf_to_dataclass."""

    # NOTE(review): the Test* prefix makes pytest attempt to collect this
    # dataclass as a test class; harmless, but may emit a collection warning.
    hidden_size: int = 0
    activation: str = "relu"
28
+
29
+
30
@dataclass
class TestTrainConfig(BaseConfig):
    """Nested config fixture: a scalar, a sub-dataclass, and a free-form dict."""

    batch_size: int = 0
    # default_factory is required for mutable defaults on dataclass fields.
    model: TestDataclass = field(default_factory=TestDataclass)
    override_config: dict = field(default_factory=dict)
35
+
36
+
37
# YAML fixture parsed by OmegaConf in the tests below; `_target_` is the dotted
# path of the dataclass that omega_conf_to_dataclass should instantiate.
_cfg_str = """train_config:
  _target_: tests.utils.test_config_on_cpu.TestTrainConfig
  batch_size: 32
  model:
    hidden_size: 768
    activation: relu
  override_config: {}"""
44
+
45
+
46
class TestConfigOnCPU(unittest.TestCase):
    """Test cases for configuration utilities on CPU.

    Test Plan:
    1. Test basic OmegaConf to dataclass conversion for simple nested structures
    2. Test nested OmegaConf to dataclass conversion for complex hierarchical configurations
    3. Verify all configuration values are correctly converted and accessible
    """

    def setUp(self):
        # Parse the shared YAML fixture once per test; each test reads a sub-tree.
        self.config = OmegaConf.create(_cfg_str)

    def test_omega_conf_to_dataclass(self):
        """A flat sub-config converts to the leaf dataclass with values intact."""
        sub_cfg = self.config.train_config.model
        cfg = omega_conf_to_dataclass(sub_cfg, TestDataclass)
        # Use unittest assertions throughout: the original bare `assert`
        # statements are silently stripped when Python runs with -O.
        self.assertIsInstance(cfg, TestDataclass)
        self.assertEqual(cfg.hidden_size, 768)
        self.assertEqual(cfg.activation, "relu")

    def test_nested_omega_conf_to_dataclass(self):
        """A nested config converts recursively, including the inner dataclass."""
        cfg = omega_conf_to_dataclass(self.config.train_config, TestTrainConfig)
        self.assertIsInstance(cfg, TestTrainConfig)
        self.assertIsInstance(cfg.model, TestDataclass)
        self.assertEqual(cfg.batch_size, 32)
        self.assertEqual(cfg.model.hidden_size, 768)
        self.assertEqual(cfg.model.activation, "relu")
72
+
73
+
74
class TestPrintCfgCommand(unittest.TestCase):
    """Test suite for the print_cfg.py command-line tool."""

    def test_command_with_override(self):
        """Smoke-test that `python3 scripts/print_cfg.py` runs and prints the config.

        NOTE(review): despite the test name, no config overrides are passed on
        the command line here; only the default invocation is exercised.
        Assumes the test runs from the repository root (relative script path).
        """
        import subprocess

        # Run the command
        result = subprocess.run(
            ["python3", "scripts/print_cfg.py"],
            capture_output=True,
            text=True,
        )

        # Verify the command exited successfully
        self.assertEqual(result.returncode, 0, f"Command failed with stderr: {result.stderr}")

        # Verify the output contains expected config information
        self.assertIn("critic", result.stdout)
        self.assertIn("profiler", result.stdout)
94
+
95
+
96
# Allow running this test module directly, outside a pytest invocation.
if __name__ == "__main__":
    unittest.main()
code/RL_model/verl/verl_train/tests/utils/test_flops_counter.py ADDED
@@ -0,0 +1,480 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+
17
+ import pytest
18
+
19
+ from verl.utils.flops_counter import FlopsCounter
20
+
21
# Model types the FLOPs counter has dedicated formulas for. Fixed: this set was
# missing "gpt_oss", "qwen3_vl" and "qwen3_vl_moe", all of which appear in the
# CONFIG table and the pytest parametrization below — keep the three in sync.
VALID_CONFIG_TYPE = {
    "llama",
    "qwen2",
    "qwen3",
    "qwen3_moe",
    "deepseek_v3",
    "mistral",
    "gemma3_text",
    "apertus",
    "gpt_oss",
    "qwen3_vl",
    "qwen3_vl_moe",
}
22
+
23
+
24
class Config:
    """Recursively expose a plain dict's entries as object attributes.

    Nested dicts become nested Config instances, mimicking just enough of a
    HuggingFace PretrainedConfig for FlopsCounter to consume.
    """

    def __init__(self, config_dict):
        for name, entry in config_dict.items():
            wrapped = Config(entry) if isinstance(entry, dict) else entry
            setattr(self, name, wrapped)
30
+
31
+
32
+ CONFIG = {
33
+ "llama": {
34
+ "config": { # llama2-7B
35
+ "model_type": "llama",
36
+ "vocab_size": 32000,
37
+ "hidden_size": 4096,
38
+ "intermediate_size": 11008,
39
+ "num_hidden_layers": 32,
40
+ "num_attention_heads": 32,
41
+ "num_key_value_heads": 32,
42
+ },
43
+ "batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]),
44
+ # 6*(vocab*hidden*2+layer*(hidden*(q+k+v+head*head_dim)+ hidden*inter*3))*token_sum +
45
+ # 6*sum(seqlen^2)*layer*head*head_dim
46
+ # 6*(32000*4096*2+32*(4096*4096*4+4096*11008*3))*(512+1024+2048) +
47
+ # 6*(512*512+1024*1024+2048*2048)*32*4096
48
+ # 6*(32000*4096*2+32*(4096*4096*4+4096*11008*3))*(4096+4096+4096) +
49
+ # 6*(4096*4096+4096*4096+4096*4096)*32*4096
50
+ "expected_flops_tuple": (149226491215872 / 1e12, 536372695793664 / 1e12),
51
+ },
52
+ "qwen2": {
53
+ "config": { # Qwen/Qwen2.5-7B-Instruct
54
+ "model_type": "qwen2",
55
+ "vocab_size": 152064,
56
+ "hidden_size": 3584,
57
+ "intermediate_size": 18944,
58
+ "num_hidden_layers": 28,
59
+ "num_attention_heads": 28,
60
+ "num_key_value_heads": 4,
61
+ },
62
+ "batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]),
63
+ # 6*(vocab*hidden*2+layer*(hidden*(q+k+v+head*head_dim)+ hidden*inter*3))*token_sum +
64
+ # 6*sum(seqlen^2)*layer*head*head_dim
65
+ # 6*(152064*3584*2+28*(3584*(3584+512+512+3584)+3584*18944*3))*(512+1024+2048) +
66
+ # 6*(512*512+1024*1024+2048*2048)*28*3584
67
+ # 6*(152064*3584*2+28*(3584*(3584+512+512+3584)+3584*18944*3))*(4096+4096+4096) +
68
+ # 6*(4096*4096+4096*4096+4096*4096)*28*3584
69
+ "expected_flops_tuple": (167073690943488 / 1e12, 591764889010176 / 1e12),
70
+ },
71
+ "qwen3": {
72
+ "config": { # Qwen/Qwen3-8B
73
+ "model_type": "qwen3",
74
+ "vocab_size": 151936,
75
+ "hidden_size": 4096,
76
+ "intermediate_size": 12288,
77
+ "num_hidden_layers": 36,
78
+ "num_attention_heads": 32,
79
+ "num_key_value_heads": 8,
80
+ "head_dim": 128,
81
+ },
82
+ "batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]),
83
+ # 6*(vocab*hidden*2+layer*(hidden*(q+k+v+head*head_dim)+ hidden*inter*3))*token_sum +
84
+ # 6*sum(seqlen^2)*layer*head*head_dim
85
+ # 6*(151936*4096*2+36*(4096*(128*32+128*8*2+128*32)+4096*12288*3))*(512+1024+2048) +
86
+ # 6*(512*512+1024*1024+2048*2048)*36*128*32
87
+ # 6*(151936*4096*2+36*(4096*(128*32+128*8*2+128*32)+4096*12288*3))*(4096+4096+4096) +
88
+ # 6*(4096*4096+4096*4096+4096*4096)*36*128*32
89
+ "expected_flops_tuple": (180997438046208 / 1e12, 648394032807936 / 1e12),
90
+ },
91
+ "qwen3_moe": {
92
+ "config": { # Qwen/Qwen3-30B-A3B-Base
93
+ "model_type": "qwen3_moe",
94
+ "hidden_size": 2048,
95
+ "vocab_size": 151936,
96
+ "num_hidden_layers": 48,
97
+ "num_key_value_heads": 4,
98
+ "num_attention_heads": 32,
99
+ "head_dim": 128,
100
+ "moe_intermediate_size": 768,
101
+ "num_experts_per_tok": 8,
102
+ "num_experts": 128,
103
+ },
104
+ "batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]),
105
+ # 6*(vocab*hidden*2+layer*(hidden*(q+k+v+head*head_dim)+hidden*inter*top_k_exp*3 +
106
+ # hidden*num_experts))*token_sum + 6*sum(seqlen^2)*layer*head*head_dim
107
+ # 6*(151936*2048*2+48*(2048*(128*32+128*4*2+128*32)+2048*768*8*3+2048*128))*(512+1024+2048) +
108
+ # 6*(512*512+1024*1024+2048*2048)*48*128*32
109
+ # 6*(151936*2048*2+48*(2048*(128*32+128*4*2+128*32)+2048*768*8*3+2048*128))*(4096+4096+4096) +
110
+ # 6*(4096*4096+4096*4096+4096*4096)*48*128*32
111
+ "expected_flops_tuple": (78593069678592 / 1e12, 306570470621184 / 1e12),
112
+ },
113
+ "deepseek_v3": {
114
+ "config": { # deepseek-ai/DeepSeek-Prover-V2-671B
115
+ "model_type": "deepseek_v3",
116
+ "hidden_size": 7168,
117
+ "vocab_size": 129280,
118
+ "moe_intermediate_size": 2048,
119
+ "num_hidden_layers": 61,
120
+ "first_k_dense_replace": 3,
121
+ "num_attention_heads": 128,
122
+ "n_routed_experts": 256,
123
+ "num_experts_per_tok": 8,
124
+ "n_shared_experts": 1,
125
+ "kv_lora_rank": 512,
126
+ "qk_rope_head_dim": 64,
127
+ "v_head_dim": 128,
128
+ "intermediate_size": 18432,
129
+ "qk_nope_head_dim": 128,
130
+ "q_lora_rank": 1536,
131
+ },
132
+ "batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]),
133
+ # (1536*7168+128*192*1536+7168*(512+64)+128*(128+128)*512+128*128*7168) = 187105280
134
+ # 6*(129280*7168*2+ 3*(7168*18432*3+187105280)+ 58*(187105280+7168*256+7168*2048*9*3))*(512+1024+2048) +
135
+ # 3*(512*512+1024*1024+2048*2048)*61*(192+128)*128
136
+ # 6*(129280*7168*2+ 3*(7168*18432*3+187105280)+ 58*(187105280+7168*256+7168*2048*9*3))*(4096+4096+4096) +
137
+ # 3*(4096*4096+4096*4096+4096*4096)*61*(192+128)*128
138
+ "expected_flops_tuple": (848766538088448 / 1e12, 3145850406567936 / 1e12),
139
+ },
140
+ "mistral": {
141
+ "config": { # mistralai/Mistral-Small-24B-Instruct-2501
142
+ "model_type": "mistral",
143
+ "vocab_size": 131072,
144
+ "hidden_size": 5120,
145
+ "intermediate_size": 32768,
146
+ "num_hidden_layers": 40,
147
+ "num_attention_heads": 32,
148
+ "num_key_value_heads": 8,
149
+ "head_dim": 128,
150
+ },
151
+ "batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]),
152
+ # Mistral uses same architecture as Llama, with GQA
153
+ # 6*(vocab*hidden*2+layer*(hidden*(q+k+v+head*head_dim)+ hidden*inter*3))*token_sum +
154
+ # 12*sum(seqlen^2)*layer*head*head_dim
155
+ # vocab part: 131072*5120*2 = 1342177280
156
+ # attn part per layer: 5120*(128*32+128*8+128*8+128*32) = 5120*10240 = 52428800
157
+ # mlp part per layer: 5120*32768*3 = 503316480
158
+ # total per layer: 52428800 + 503316480 = 555745280
159
+ # all layers: 1342177280 + 40*555745280 = 23571988480
160
+ # For batch [512, 1024, 2048], tokens_sum = 3584:
161
+ # dense flops: 6 * 23571988480 * 3584 = 506892040273920
162
+ # attn flops: 6 * 5505024 * 40 * 128 * 32 = 10823317585920
163
+ # total: 517715357859840 / 1e12 = 517.71535785984
164
+ # For batch [4096, 4096, 4096], tokens_sum = 12288:
165
+ # dense flops: 6 * 23571988480 * 12288 = 1737915566653440
166
+ # attn flops: 6 * 50331648 * 40 * 128 * 32 = 98956046499840
167
+ # total: 1836871613153280 / 1e12 = 1836.87161315328
168
+ "expected_flops_tuple": (512303699066880 / 1e12, 1787393589903360 / 1e12),
169
+ },
170
+ "gemma3_text": {
171
+ "config": { # Gemma3-12B-IT-TextOnly
172
+ "model_type": "gemma3_text",
173
+ "vocab_size": 262208,
174
+ "hidden_size": 3840,
175
+ "intermediate_size": 15360,
176
+ "num_hidden_layers": 48,
177
+ "num_attention_heads": 16,
178
+ "num_key_value_heads": 8,
179
+ "head_dim": 256,
180
+ "sliding_window": 1024,
181
+ "layer_types": None,
182
+ # Will be auto-generated based on sliding_window_pattern
183
+ "sliding_window_pattern": 6,
184
+ # Every 6th layer is full attention
185
+ },
186
+ "batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]),
187
+ # Gemma3 has alternating sliding window attention
188
+ # With sliding_window_pattern=6: layers 5,11,17,23,29,35,41,47 use full attention (8 layers)
189
+ # Other 40 layers use sliding window attention with window_size=1024
190
+ #
191
+ # Non-attention FLOPs:
192
+ # vocab part: 262208*3840*2 = 2013757440
193
+ # attn part per layer: 3840*(256*16+256*8+256*8+256*16) = 3840*12288 = 47185920
194
+ # mlp part per layer: 3840*15360*3 = 176947200
195
+ # total per layer: 47185920 + 176947200 = 224133120
196
+ # all layers: 2013757440 + 48*224133120 = 12772147200
197
+ #
198
+ # For batch [512, 1024, 2048], tokens_sum = 3584:
199
+ # dense flops: 6 * 12772147200 * 3584 = 274652253388800
200
+ # seqlen_square_sum: 180355072 (calculated with sliding window logic)
201
+ # attn flops: 6 * 180355072 * 256 * 16 = 8864812498944
202
+ # total: 283517065887744 / 1e12 = 283.517065887744
203
+ #
204
+ # For batch [4096, 4096, 4096], tokens_sum = 12288:
205
+ # dense flops: 6 * 12772147200 * 12288 = 941664868761600
206
+ # seqlen_square_sum: 905969664 (calculated with sliding window logic)
207
+ # attn flops: 6 * 905969664 * 256 * 16 = 44530220924928
208
+ # total: 986195089686528 / 1e12 = 986.195089686528
209
+ "expected_flops_tuple": (279084659638272 / 1e12, 963929979224064 / 1e12),
210
+ },
211
+ "gpt_oss": {
212
+ "config": {
213
+ "model_type": "gpt_oss",
214
+ "vocab_size": 201088,
215
+ "hidden_size": 2880,
216
+ "num_hidden_layers": 24,
217
+ "num_attention_heads": 64,
218
+ "num_key_value_heads": 8,
219
+ "head_dim": 64,
220
+ "intermediate_size": 2880,
221
+ "num_local_experts": 32,
222
+ "num_experts_per_tok": 4,
223
+ "sliding_window": 128,
224
+ "layer_types": [
225
+ "sliding_attention",
226
+ "full_attention",
227
+ "sliding_attention",
228
+ "full_attention",
229
+ "sliding_attention",
230
+ "full_attention",
231
+ "sliding_attention",
232
+ "full_attention",
233
+ "sliding_attention",
234
+ "full_attention",
235
+ "sliding_attention",
236
+ "full_attention",
237
+ "sliding_attention",
238
+ "full_attention",
239
+ "sliding_attention",
240
+ "full_attention",
241
+ "sliding_attention",
242
+ "full_attention",
243
+ "sliding_attention",
244
+ "full_attention",
245
+ "sliding_attention",
246
+ "full_attention",
247
+ "sliding_attention",
248
+ "full_attention",
249
+ ],
250
+ },
251
+ "batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]),
252
+ # GPT-OSS has alternating sliding / full attention
253
+ # Even layers (12 layers) use sliding window attention with window_size = 128
254
+ # Odd layers (12 layers) use full attention
255
+ #
256
+ # Non-attention FLOPs:
257
+ # vocab part: 201088 * 2880 * 2 = 1158266880
258
+ # attn linear part per layer:
259
+ # Q: 2880 * (64 * 64) = 11796480
260
+ # K: 2880 * (8 * 64) = 1474560
261
+ # V: 2880 * (8 * 64) = 1474560
262
+ # O: (64 * 64) * 2880 = 11796480
263
+ # attn linear total = 26542080
264
+ # mlp (MoE, SwiGLU) part per layer:
265
+ # gate: 2880 * 32 = 92160
266
+ # active experts: 3 * 2880 * 2880 * 4 = 99532800
267
+ # mlp total = 99624960
268
+ # total per layer: 26542080 + 99624960 = 126167040
269
+ # all layers:
270
+ # 126167040 * 24 = 3028008960
271
+ # total dense params:
272
+ # 3028008960 + 1158266880 = 4186275840
273
+ #
274
+ # For batch [512, 1024, 2048], tokens_sum = 3584:
275
+ # dense flops: 6 * 4186275840 * 3584 = 90021675663360
276
+ # seqlen_square_sum: 71565312 (calculated with sliding window logic)
277
+ # attn flops: 6 * 71565312 * 64 * 64 = 3517578215424
278
+ # total: 93539253878784 / 1e12 = 93.539253878784
279
+ #
280
+ # For batch [4096, 4096, 4096], tokens_sum = 12288:
281
+ # dense flops: 6 * 4186275840 * 12288 = 308646629068800
282
+ # seqlen_square_sum: 622854144 (calculated with sliding window logic)
283
+ # attn flops: 6 * 622854144 * 64 * 64 = 30613642948608
284
+ # total: 339260272017408 / 1e12 = 339.260272017408
285
+ "expected_flops_tuple": (91780464771072 / 1e12, 323953008574464 / 1e12),
286
+ },
287
+ "apertus": {
288
+ "config": { # swiss-ai/Apertus-8B
289
+ "model_type": "apertus",
290
+ "vocab_size": 131072,
291
+ "hidden_size": 4096,
292
+ "intermediate_size": 21504,
293
+ "num_hidden_layers": 32,
294
+ "num_attention_heads": 32,
295
+ "num_key_value_heads": 32,
296
+ "hidden_act": "xielu",
297
+ # head_dim will be derived as 4096 / 32 = 128
298
+ },
299
+ "batch_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]),
300
+ # Calculation for Apertus (hidden_act="xielu" -> MLP uses [k_mlp=2]*H*I params; qk_norm=True -> [k_qkn=2]*H):
301
+ # V=131072, H=4096, I=21504, L=32, k_mlp=2 (XIELU), k_qkn=2 (QK norm), S=6
302
+ # S*(2*V*H + L*(4*H**2 + k_mlp*H*I + k_qkn*H)) * (SUM[seqlen]) + 6*SUM[seqlen**2]*L*H
303
+ "expected_flops_tuple": (194825353691136 / 1e12, 692711652851712 / 1e12),
304
+ },
305
+ "qwen3_vl": {
306
+ "config": { # Qwen/Qwen3-VL-8B
307
+ "model_type": "qwen3_vl",
308
+ # -------- Text config --------
309
+ "text_config": {
310
+ "vocab_size": 151936,
311
+ "hidden_size": 4096,
312
+ "intermediate_size": 12288,
313
+ "num_hidden_layers": 36,
314
+ "num_attention_heads": 32,
315
+ "num_key_value_heads": 8,
316
+ "head_dim": 128,
317
+ },
318
+ # -------- Vision config (ViT) --------
319
+ "vision_config": {
320
+ "deepstack_visual_indexes": [8, 16, 24],
321
+ "num_heads": 16,
322
+ "depth": 27,
323
+ "hidden_size": 1152,
324
+ "intermediate_size": 4304,
325
+ "out_hidden_size": 4096,
326
+ "spatial_merge_size": 2,
327
+ "temporal_patch_size": 2,
328
+ "in_channels": 3,
329
+ "patch_size": 16,
330
+ },
331
+ },
332
+ "batch_seqlens_tuple": (
333
+ [512, 1024, 2048],
334
+ [4096, 4096, 4096],
335
+ ),
336
+ "images_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]),
337
+ # -----Text-----
338
+ # 6*(vocab*hidden*2
339
+ # + layer*(hidden*(q+k+v+o) + hidden*inter*3)
340
+ # )*token_sum
341
+ # + 6*sum(seqlen^2)*layer*hidden
342
+ #
343
+ # -----ViT-----
344
+ # patch_embed_N =hidden*temporal_patch_size*in_channels* patch_size^2
345
+ # attn_linear_N =hidden*(4*hidden)
346
+ # mlp_N =hidden*inter*2
347
+ # merger_N =((o+hidden*spatial_merge_size^2) * (hidden*spatial_merge_size^2))
348
+ # deepstack_merger_N =merger_N * 3
349
+ # dense_N =patch_embed_N + (attn_linear_N + mlp_N) * 27 + deepstack_merger_N + merger_N
350
+ #
351
+ # 6*(151936*4096*2
352
+ # + 36*(4096*(4096+1024+1024+4096) + 4096*12288*3)
353
+ # )*(512+1024+2048)
354
+ # + 12*(512*512+1024*1024+2048*2048)*36*4096
355
+ # + 6 * dense_N * (512 + 1024 + 2048)
356
+ # + 12 * (512**2 + 1024**2 + 2048**2) * 27 * 16 * 72
357
+ #
358
+ # 6*(151936*4096*2
359
+ # + 36*(4096*(4096+1024+1024+4096) + 4096*12288*3)
360
+ # )*(4096+4096+4096)
361
+ # + 12*(4096*4096+4096*4096+4096*4096)*36*4096
362
+ # + 6 * dense_N * (4096 + 4096 + 2048)
363
+ # + 12 * (4096**2 + 4096**2 + 4096**2) * 27 * 16 * 72
364
+ "expected_flops_tuple": (
365
+ 195379819708416 / 1e12,
366
+ 709446422495232 / 1e12,
367
+ ),
368
+ },
369
+ "qwen3_vl_moe": {
370
+ "config": { # Qwen/Qwen3-VL-30B-A3B
371
+ "model_type": "qwen3_vl_moe",
372
+ # -------- Text config --------
373
+ "text_config": {
374
+ "vocab_size": 151936,
375
+ "hidden_size": 2048,
376
+ "num_hidden_layers": 48,
377
+ "num_attention_heads": 32,
378
+ "num_key_value_heads": 4,
379
+ "head_dim": 128,
380
+ "moe_intermediate_size": 768,
381
+ "num_experts": 128,
382
+ "num_experts_per_tok": 8,
383
+ },
384
+ # -------- Vision config (ViT) --------
385
+ "vision_config": {
386
+ "deepstack_visual_indexes": [8, 16, 24],
387
+ "num_heads": 16,
388
+ "depth": 27,
389
+ "hidden_size": 1152,
390
+ "intermediate_size": 4304,
391
+ "out_hidden_size": 4096,
392
+ "spatial_merge_size": 2,
393
+ "temporal_patch_size": 2,
394
+ "in_channels": 3,
395
+ "patch_size": 16,
396
+ },
397
+ },
398
+ "batch_seqlens_tuple": (
399
+ [512, 1024, 2048],
400
+ [4096, 4096, 4096],
401
+ ),
402
+ "images_seqlens_tuple": ([512, 1024, 2048], [4096, 4096, 4096]),
403
+ # -----Text-----
404
+ # 6*(vocab*hidden*2
405
+ # + layer*(hidden*(q+k+v+head*head_dim)+hidden*inter*top_k_exp*3+hidden*num_experts)
406
+ # )*token_sum
407
+ # + 6*sum(seqlen^2)*layer*hidden
408
+ #
409
+ # -----ViT-----
410
+ # patch_embed_N =hidden*temporal_patch_size*in_channels* patch_size^2
411
+ # attn_linear_N =hidden*(4*hidden)
412
+ # mlp_N =hidden*inter*2
413
+ # merger_N =((o+hidden*spatial_merge_size^2) * (hidden*spatial_merge_size^2))
414
+ # deepstack_merger_N =merger_N * 3
415
+ # dense_N =patch_embed_N + (attn_linear_N + mlp_N) * 27 + deepstack_merger_N + merger_N
416
+ #
417
+ # 6*(151936*2048*2
418
+ # + 48*(2048*(128*32+128*4*2+128*32)+2048*768*8*3+2048*128)
419
+ # )*(512+1024+2048)
420
+ # + 12*(512*512+1024*1024+2048*2048)*48*4096
421
+ # + 6 * dense_N * (512 + 1024 + 2048)
422
+ # + 12 * (512**2 + 1024**2 + 2048**2) * 27 * 16 * 72
423
+ #
424
+ # 6*(151936*2048*2
425
+ # 48*(2048*(128*32+128*4*2+128*32)+2048*768*8*3+2048*128)
426
+ # )*(4096+4096+4096)
427
+ # + 12*(4096*4096+4096*4096+4096*4096)*48*4096
428
+ # + 6 * dense_N * (4096 + 4096 + 2048)
429
+ # + 12 * (4096**2 + 4096**2 + 4096**2) * 27 * 16 * 72
430
+ "expected_flops_tuple": (
431
+ 92975451340800 / 1e12,
432
+ 367622860308480 / 1e12,
433
+ ),
434
+ },
435
+ }
436
+
437
+
438
@pytest.mark.parametrize(
    "config_type",
    [
        "llama",
        "qwen2",
        "qwen3",
        "qwen3_moe",
        "deepseek_v3",
        "mistral",
        "gemma3_text",
        "apertus",
        "gpt_oss",
        "qwen3_vl",
        "qwen3_vl_moe",
    ],
)
def test_flops_counter(config_type: str):
    """FlopsCounter.estimate_flops matches the hand-derived FLOPs for each architecture."""
    case = CONFIG[config_type]
    counter = FlopsCounter(Config(case["config"]))

    if "images_seqlens_tuple" in case:
        # Vision-language configs additionally supply per-image sequence lengths.
        batches = zip(
            case["batch_seqlens_tuple"],
            case["images_seqlens_tuple"],
            case["expected_flops_tuple"],
            strict=True,
        )
        for batch_seqlens, images_seqlens, expected_flops in batches:
            # set delta time to 1 to get the flops
            counted_flops, _ = counter.estimate_flops(batch_seqlens, 1, images_seqlens=images_seqlens)
            print(f"Expect flops for {case['config']} is {expected_flops}, but get {counted_flops}")
            assert math.isclose(counted_flops, expected_flops), (
                f"Expect flops for {case['config']} is {expected_flops}, but get {counted_flops}"
            )
    else:
        batches = zip(case["batch_seqlens_tuple"], case["expected_flops_tuple"], strict=True)
        for batch_seqlens, expected_flops in batches:
            # set delta time to 1 to get the flops
            counted_flops, _ = counter.estimate_flops(batch_seqlens, 1)
            print(f"Expect flops for {case['config']} is {expected_flops}, but get {counted_flops}")
            assert math.isclose(counted_flops, expected_flops), (
                f"Expect flops for {case['config']} is {expected_flops}, but get {counted_flops}"
            )
code/RL_model/verl/verl_train/tests/utils/test_fs_on_cpu.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ from pathlib import Path
17
+
18
+ import verl.utils.fs as fs
19
+
20
+
21
def test_record_and_check_directory_structure(tmp_path):
    """A directory snapshot stays valid until the tree's contents change."""
    root = tmp_path / "test_dir"
    root.mkdir()
    (root / "file1.txt").write_text("test")
    sub = root / "subdir"
    sub.mkdir()
    (sub / "file2.txt").write_text("test")

    # Snapshot the current structure to a record file.
    record = fs._record_directory_structure(root)
    assert os.path.exists(record)

    # Unchanged tree -> check passes.
    assert fs._check_directory_structure(root, record) is True

    # Any new entry invalidates the snapshot.
    (root / "new_file.txt").write_text("test")
    assert fs._check_directory_structure(root, record) is False
41
+
42
+
43
def test_copy_from_hdfs_with_mocks(tmp_path, monkeypatch):
    """copy_to_local should land a remote file at <cache>/<md5(src)>/<basename>."""
    # Force the "remote path" branch.
    monkeypatch.setattr(fs, "is_non_local", lambda path: True)

    def fake_copy(src: str, dst: str, *args, **kwargs):
        # Stand-in for the real HDFS copy: materialize an empty file at dst.
        target = Path(dst)
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_bytes(b"")

    monkeypatch.setattr(fs, "copy", fake_copy)

    cache_dir = tmp_path / "cache"
    remote = "hdfs://test/path/file.txt"

    local_path = fs.copy_to_local(remote, cache_dir=cache_dir)
    want = os.path.join(cache_dir, fs.md5_encode(remote), os.path.basename(remote))
    assert local_path == want
    assert os.path.exists(local_path)
64
+
65
+
66
def test_always_recopy_flag(tmp_path, monkeypatch):
    """always_recopy=True must re-download even when the cache entry exists."""
    monkeypatch.setattr(fs, "is_non_local", lambda path: True)

    calls = []

    def fake_copy(src: str, dst: str, *args, **kwargs):
        # Count invocations and drop an empty file where the copy would land.
        calls.append(src)
        target = Path(dst)
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_bytes(b"")

    monkeypatch.setattr(fs, "copy", fake_copy)

    cache_dir = tmp_path / "cache"
    remote = "hdfs://test/path/file.txt"

    # Cold cache -> one real copy.
    fs.copy_to_local(remote, cache_dir=cache_dir)
    assert len(calls) == 1

    # Forced -> copies again despite the warm cache.
    fs.copy_to_local(remote, cache_dir=cache_dir, always_recopy=True)
    assert len(calls) == 2

    # Warm cache without the flag -> no new copy.
    fs.copy_to_local(remote, cache_dir=cache_dir)
    assert len(calls) == 2
code/RL_model/verl/verl_train/tests/utils/test_groupwise.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ # Copyright 2023-2024 SGLang Team
3
+ # Copyright 2025 ModelBest Inc. and/or its affiliates
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import os
17
+
18
+ os.environ.setdefault("VERL_FORCE_DEVICE", "cpu") # ensure CPU for tests
19
+
20
+ import numpy as np
21
+ import pytest
22
+ import torch
23
+
24
+ from verl.utils import as_torch_index, group_mean_std
25
+
26
+
27
+ def test_as_torch_index_basic_integers():
28
+ g = as_torch_index([2, 2, 5, 7, 5, 2])
29
+ assert g.dtype == torch.long
30
+ assert g.device.type == "cpu"
31
+ # Values should be contiguous 0..G-1, keeping equal labels equal
32
+ assert g.tolist()[0] == g.tolist()[1]
33
+ assert len(torch.unique(g)) == 3 # {2,5,7} -> 3 groups
34
+
35
+
36
+ def test_as_torch_index_near_integer_floats():
37
+ arr = np.array([1.0000001, 2.0, 1.0, 3.0000000001], dtype=np.float64)
38
+ g = as_torch_index(arr) # should round to integers then factorize
39
+ assert g.dtype == torch.long
40
+ assert len(torch.unique(g)) == 3 # {1,2,3}
41
+
42
+
43
+ def test_as_torch_index_factorization_mixed():
44
+ labels = ["a", "b", "a", "c", "0042", 42]
45
+ g = as_torch_index(labels)
46
+ # "0042" and 42 should NOT be the same group (strings are not coerced here)
47
+ assert g.tolist()[4] != g.tolist()[5]
48
+ assert len(torch.unique(g)) == 5
49
+
50
+
51
+ def test_group_mean_std_simple():
52
+ # groups: 0 -> [1, 3], 1 -> [2]
53
+ scores = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
54
+ gidx = as_torch_index([0, 1, 0])
55
+
56
+ mean_g, std_g, cnt_g = group_mean_std(scores, gidx)
57
+ # group 0: mean = (1+3)/2 = 2
58
+ # sample std (unbiased) = sqrt( (sum(x^2) - (sum(x)^2)/n) / (n-1) )
59
+ # = sqrt( (1^2+3^2) - (1+3)^2/2 ) / (2-1) = sqrt(10 - 16/2) = sqrt(2)
60
+ assert torch.allclose(mean_g, torch.tensor([2.0, 0.0]))
61
+ assert torch.allclose(cnt_g, torch.tensor([2.0, 1.0]))
62
+ # singleton group -> std = 1.0
63
+ assert mean_g[1].item() == 0.0
64
+ assert std_g[1].item() == 1.0
65
+ assert pytest.approx(std_g[0].item(), rel=1e-6) == (2.0**0.5)
66
+
67
+
68
+ def test_group_mean_std_empty():
69
+ scores = torch.tensor([], dtype=torch.float32)
70
+ gidx = torch.tensor([], dtype=torch.long)
71
+ mean_g, std_g, cnt_g = group_mean_std(scores, gidx)
72
+ assert mean_g.numel() == 0 and std_g.numel() == 0 and cnt_g.numel() == 0
73
+
74
+
75
+ def test_group_mean_std_default_device_no_force_env(monkeypatch):
76
+ """
77
+ Regression test:
78
+ - group_mean_std(device=None) must not pass a device *module* (e.g., torch.cuda)
79
+ into Tensor.to(device=...), which crashes with:
80
+ TypeError: to() received an invalid combination of arguments - got (..., device=module, ...)
81
+ """
82
+ # Simulate a non-pytest environment (training code path) while keeping the test CPU-only.
83
+ monkeypatch.delenv("VERL_FORCE_DEVICE", raising=False)
84
+ monkeypatch.delenv("PYTEST_CURRENT_TEST", raising=False)
85
+
86
+ # Force device selection to CPU even if CUDA is available on the test machine.
87
+ import verl.utils.device as device_mod
88
+
89
+ monkeypatch.setattr(device_mod, "is_cuda_available", False)
90
+ monkeypatch.setattr(device_mod, "is_npu_available", False)
91
+
92
+ scores = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
93
+ gidx = torch.tensor([0, 1, 0], dtype=torch.long)
94
+
95
+ mean_g, std_g, cnt_g = group_mean_std(scores, gidx)
96
+ assert mean_g.device.type == "cpu"
97
+ assert std_g.device.type == "cpu"
98
+ assert cnt_g.device.type == "cpu"
code/RL_model/verl/verl_train/tests/utils/test_import_utils_on_cpu.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+
17
+ import pytest
18
+
19
+ from verl.utils.import_utils import load_extern_object
20
+
21
# Path to the companion fixture module (_test_module.py, next to this file)
# from which the loader tests pull a class, a function, and a constant.
TEST_MODULE_PATH = os.path.join(os.path.dirname(__file__), "_test_module.py")
23
+
24
+
25
def test_load_extern_object_class():
    """A class can be loaded from an external file and instantiated."""
    loaded_cls = load_extern_object(TEST_MODULE_PATH, "TestClass")
    assert loaded_cls is not None
    assert loaded_cls.__name__ == "TestClass"

    # Default construction.
    assert loaded_cls().value == "default"
    # Construction with an explicit value, read back through the accessor.
    assert loaded_cls("custom").get_value() == "custom"
40
+
41
+
42
def test_load_extern_object_function():
    """A function can be loaded from an external file and called."""
    fn = load_extern_object(TEST_MODULE_PATH, "test_function")
    assert fn is not None
    assert callable(fn)
    # Calling the loaded function returns its canned result.
    assert fn() == "test_function_result"
53
+
54
+
55
def test_load_extern_object_constant():
    """A module-level constant can be loaded from an external file."""
    value = load_extern_object(TEST_MODULE_PATH, "TEST_CONSTANT")
    assert value is not None
    assert value == "test_constant_value"
62
+
63
+
64
def test_load_extern_object_nonexistent_file():
    """Loading from a missing file surfaces as FileNotFoundError."""
    with pytest.raises(FileNotFoundError):
        load_extern_object("/nonexistent/path.py", "SomeType")
68
+
69
+
70
def test_load_extern_object_nonexistent_type():
    """Requesting a name absent from the module raises AttributeError."""
    with pytest.raises(AttributeError):
        load_extern_object(TEST_MODULE_PATH, "NonExistentType")
74
+
75
+
76
def test_load_extern_object_none_path():
    """A None file path raises AttributeError."""
    with pytest.raises(AttributeError):
        load_extern_object(None, "SomeType")
80
+
81
+
82
def test_load_extern_object_invalid_module():
    """A module that fails to parse surfaces as RuntimeError."""
    import tempfile

    # Write a throwaway .py file containing invalid syntax.
    with tempfile.NamedTemporaryFile(suffix=".py", mode="w+", delete=False) as handle:
        handle.write("This is not valid Python syntax :")
        broken_path = handle.name

    try:
        with pytest.raises(RuntimeError):
            load_extern_object(broken_path, "SomeType")
    finally:
        # Always remove the temporary file, pass or fail.
        if os.path.exists(broken_path):
            os.remove(broken_path)
code/RL_model/verl/verl_train/tests/utils/test_linear_cross_entropy.py ADDED
@@ -0,0 +1,361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+
18
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
19
+ #
20
+ # Licensed under the Apache License, Version 2.0 (the "License");
21
+ # you may not use this file except in compliance with the License.
22
+ # You may obtain a copy of the License at
23
+ #
24
+ # http://www.apache.org/licenses/LICENSE-2.0
25
+ #
26
+ # Unless required by applicable law or agreed to in writing, software
27
+ # distributed under the License is distributed on an "AS IS" BASIS,
28
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29
+ # See the License for the specific language governing permissions and
30
+ # limitations under the License.
31
+
32
+ import os
33
+
34
+ import torch
35
+
36
+ import verl.utils.torch_functional as verl_F
37
+ from verl.utils.experimental.torch_functional import FusedLinearForPPO
38
+ from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy
39
+ from verl.utils.torch_functional import logprobs_from_logits
40
+
41
# One-time module-level setup: compile verl's entropy kernel and the fused
# linear+PPO helper with dynamic shapes so every benchmark iteration below
# reuses the compiled graphs instead of recompiling per shape.
compute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True)
fused_linear_for_ppo = FusedLinearForPPO()
fused_linear_for_ppo.compile(dynamic=True)
44
+
45
# Coerce to int: os.environ.get returns a *string* whenever the variable is
# set, which would break `range(MAX_TEST_CASES)` in __main__ and make the
# `MAX_TEST_CASES <= 5` sanity check raise TypeError.
MAX_TEST_CASES = int(os.environ.get("MAX_TEST_CASES", 5))
46
+
47
+
48
def run_torch_entropy(
    hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, temperature: float, reduction="none"
) -> list[torch.Tensor]:
    """Reference fp32 path: project hidden states to logits, then return
    (per-token label log-probs, per-token entropy)."""
    h32 = hidden.squeeze(0).to(torch.float32)  # [num_tokens, hidden_size]
    w32 = weight.transpose(0, 1).to(torch.float32)  # [hidden_size, vocab_size]
    logits = torch.matmul(h32, w32)  # [num_tokens, vocab_size]
    logits /= temperature
    # Entropy per token: logsumexp(z) - sum(softmax(z) * z).
    probs = torch.nn.functional.softmax(logits, dim=-1)  # [num_tokens, vocab_size]
    lse = torch.logsumexp(logits, dim=-1)  # [num_tokens]
    expected_logit = torch.sum(probs * logits, dim=-1)  # [num_tokens]
    entropy = lse - expected_logit
    # Label log-prob = -cross_entropy (reduction defaults to per-token values).
    nll = torch.nn.functional.cross_entropy(logits, labels.squeeze(0), reduction=reduction)
    logprobs = torch.neg(nll)
    return logprobs, entropy
62
+
63
+
64
def run_verl_original_entropy(
    hidden: torch.Tensor,
    weight: torch.Tensor,
    labels: torch.Tensor,
    temperature: float,
) -> list[torch.Tensor]:
    """verl's production path: fp32 matmul logits, compiled entropy kernel,
    and logprobs_from_logits for the label log-probs."""
    flat_hidden = hidden.squeeze(0).to(torch.float32)
    proj = weight.transpose(0, 1).to(torch.float32)
    logits = torch.matmul(flat_hidden, proj)  # [num_tokens, vocab_size]
    logits /= temperature
    # Entropy via the module-level torch.compile'd kernel
    # (shape is ((total_nnz / sp) + pad) under sequence parallel).
    entropy = compute_entropy_from_logits(logits)
    # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen).
    # inplace_backward disabled so autograd records plain ops.
    logprobs = logprobs_from_logits(logits=logits, labels=labels, inplace_backward=False)
    return logprobs, entropy
79
+
80
+
81
+ # To be tested
82
# To be tested
def run_verl_torch_fused_entropy(
    hidden: torch.Tensor,
    weight: torch.Tensor,
    labels: torch.Tensor,
    temperature: float,
):
    """Fused linear + PPO-loss path via the module-level compiled helper."""
    logprobs, entropy = fused_linear_for_ppo(
        hidden.to(torch.float32),
        weight.to(torch.float32),
        labels,
        temperature=temperature,
    )
    # Drop the leading batch dim of 1 to match the other implementations.
    return logprobs.squeeze(0), entropy.squeeze(0)
97
+
98
+
99
class TestLinearCrossEntropy:
    """Benchmark/correctness harness (driven from __main__, not pytest) that
    compares four (logprobs, entropy) implementations on CUDA: plain torch,
    verl's compiled path, verl's fused-linear path, and the custom kernel.
    """

    def __init__(self, test_case_idx: int, temperature: float = 1.5) -> None:
        # test_case_idx selects one of the fixed shape presets in generate_hyper().
        self.test_case_idx = test_case_idx
        self.temperature = temperature

    def cleanup(self):
        """Release cached CUDA memory and reset peak-memory statistics."""
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        import gc

        gc.collect()
        torch.cuda.synchronize()

    def generate_hyper(self):
        """Populate dtype and batch/token/hidden/vocab sizes for the chosen case."""
        global MAX_TEST_CASES

        self.dtype = torch.bfloat16
        if self.test_case_idx == 0:
            self.batch_size = 1
            self.num_tokens = 1937
            self.hidden_size = 3584
            self.vocab_size = 152064
        elif self.test_case_idx == 1:
            self.batch_size = 1
            self.num_tokens = 2169
            self.hidden_size = 896
            self.vocab_size = 151936
        elif self.test_case_idx == 2:
            self.batch_size = 1
            self.num_tokens = 1530
            self.hidden_size = 2048
            self.vocab_size = 32256
        elif self.test_case_idx == 3:
            self.batch_size = 1
            self.num_tokens = 1388
            self.hidden_size = 4096
            self.vocab_size = 102400
        elif self.test_case_idx == 4:
            self.batch_size = 1
            self.num_tokens = 8192
            self.hidden_size = 4096
            self.vocab_size = 102400
        else:
            raise ValueError(f"Invalid test case index: {self.test_case_idx}")
        assert MAX_TEST_CASES <= 5, "MAX_TEST_CASES should be less than or equal to 5."

    def generate_forward_inputs(self):
        """Return fresh random (hidden, weight, labels) CUDA tensors; hidden and
        weight require grad so backward timing/correctness can be measured."""
        hidden = (
            torch.empty((self.batch_size, self.num_tokens, self.hidden_size), dtype=self.dtype, device="cuda")
            .uniform_(-0.5, 0.5)
            .requires_grad_()
        )
        weight = (
            torch.empty((self.vocab_size, self.hidden_size), dtype=self.dtype, device="cuda")
            .uniform_(-0.5, 0.5)
            .requires_grad_()
        )
        labels = torch.randint(0, self.vocab_size, (self.batch_size, self.num_tokens), device="cuda")
        return hidden, weight, labels

    def generate_backward_inputs(self):
        """Random upstream gradients for (entropy, logprobs)."""
        g_entropy = torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-0.5, 0.5)
        g_logprobs = torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-1, 1)
        return g_entropy, g_logprobs

    def verify_correctness(self, iterations=5):
        """Run all four implementations, cross-check outputs and gradients, and
        print average forward/backward latency (first iteration discarded)."""
        self.cleanup()
        self.generate_hyper()

        torch_forward_latency = list()
        torch_backward_latency = list()
        verl_forward_latency = list()
        verl_backward_latency = list()
        verl_fused_forward_latency = list()
        verl_fused_backward_latency = list()
        kernel_forward_latency = list()
        kernel_backward_latency = list()

        # CUDA events measure device-side time; host timers would include gaps.
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

        for i in range(iterations):
            print(f"[INFO]: Iteration {i + 1} / {iterations}...", end="\r")
            hidden, weight, labels = self.generate_forward_inputs()

            start_event.record()
            (torch_logprobs, torch_entropy) = run_torch_entropy(hidden, weight, labels, self.temperature)
            end_event.record()
            torch.cuda.synchronize()
            torch_forward_latency.append(start_event.elapsed_time(end_event))

            start_event.record()
            (verl_logprobs, verl_entropy) = run_verl_original_entropy(hidden, weight, labels, self.temperature)
            end_event.record()
            torch.cuda.synchronize()
            verl_forward_latency.append(start_event.elapsed_time(end_event))

            start_event.record()
            (verl_fused_logprobs, verl_fused_entropy) = run_verl_torch_fused_entropy(
                hidden, weight, labels, self.temperature
            )
            end_event.record()
            torch.cuda.synchronize()
            verl_fused_forward_latency.append(start_event.elapsed_time(end_event))

            start_event.record()
            (kernel_logprobs, kernel_entropy) = linear_cross_entropy(hidden, weight, labels, self.temperature)
            end_event.record()
            torch.cuda.synchronize()
            kernel_forward_latency.append(start_event.elapsed_time(end_event))

            # The three torch-based paths are held to tight tolerances.
            torch.testing.assert_close(torch_logprobs, verl_logprobs, atol=1e-4, rtol=1e-4)
            torch.testing.assert_close(torch_entropy, verl_entropy, atol=1e-4, rtol=1e-4)

            torch.testing.assert_close(torch_logprobs, verl_fused_logprobs, atol=1e-4, rtol=1e-4)
            torch.testing.assert_close(torch_entropy, verl_fused_entropy, atol=1e-4, rtol=1e-4)
            torch.testing.assert_close(verl_logprobs, verl_fused_logprobs, atol=1e-4, rtol=1e-4)
            torch.testing.assert_close(verl_entropy, verl_fused_entropy, atol=1e-4, rtol=1e-4)

            # The custom kernel gets looser tolerances against all references.
            torch.testing.assert_close(torch_logprobs, kernel_logprobs, atol=1e-3, rtol=2e-4)
            torch.testing.assert_close(torch_entropy, kernel_entropy, atol=5e-3, rtol=5e-4)
            torch.testing.assert_close(verl_logprobs, kernel_logprobs, atol=1e-3, rtol=2e-4)
            torch.testing.assert_close(verl_entropy, kernel_entropy, atol=5e-3, rtol=5e-4)
            torch.testing.assert_close(verl_fused_logprobs, kernel_logprobs, atol=1e-3, rtol=2e-4)
            torch.testing.assert_close(verl_fused_entropy, kernel_entropy, atol=5e-3, rtol=5e-4)

            # backward
            g_entropy, g_logprobs = self.generate_backward_inputs()

            start_event.record()
            (d_torch_hidden, d_torch_weight) = torch.autograd.grad(
                (torch_entropy, torch_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False
            )
            end_event.record()
            torch.cuda.synchronize()
            torch_backward_latency.append(start_event.elapsed_time(end_event))

            start_event.record()
            (d_verl_hidden, d_verl_weight) = torch.autograd.grad(
                (verl_entropy, verl_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False
            )
            end_event.record()
            torch.cuda.synchronize()
            verl_backward_latency.append(start_event.elapsed_time(end_event))

            start_event.record()
            (d_verl_fused_hidden, d_verl_fused_weight) = torch.autograd.grad(
                (verl_fused_entropy, verl_fused_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False
            )
            end_event.record()
            torch.cuda.synchronize()
            verl_fused_backward_latency.append(start_event.elapsed_time(end_event))

            start_event.record()
            (d_kernel_hidden, d_kernel_weight) = torch.autograd.grad(
                (kernel_entropy, kernel_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False
            )
            end_event.record()
            torch.cuda.synchronize()
            kernel_backward_latency.append(start_event.elapsed_time(end_event))

            # Gradients: torch-based paths against each other (tight-ish).
            torch.testing.assert_close(d_torch_hidden, d_verl_hidden, atol=1e-2, rtol=1e-4)
            torch.testing.assert_close(d_torch_weight, d_verl_weight, atol=1e-2, rtol=1e-4)

            torch.testing.assert_close(d_torch_hidden, d_verl_fused_hidden, atol=1e-2, rtol=1e-4)
            torch.testing.assert_close(d_torch_weight, d_verl_fused_weight, atol=1e-2, rtol=1e-4)
            torch.testing.assert_close(d_verl_hidden, d_verl_fused_hidden, atol=1e-2, rtol=1e-4)
            torch.testing.assert_close(d_verl_weight, d_verl_fused_weight, atol=1e-2, rtol=1e-4)
            torch.testing.assert_close(d_torch_hidden, d_verl_hidden, atol=1e-2, rtol=1e-4)
            torch.testing.assert_close(d_torch_weight, d_verl_weight, atol=1e-2, rtol=1e-4)

            # Gradients: kernel path, with the loosest tolerances.
            torch.testing.assert_close(d_torch_hidden, d_kernel_hidden, atol=2e-2, rtol=4e-2)
            torch.testing.assert_close(d_torch_weight, d_kernel_weight, atol=2e-2, rtol=4e-2)
            torch.testing.assert_close(d_verl_hidden, d_kernel_hidden, atol=2e-2, rtol=4e-2)
            torch.testing.assert_close(d_verl_weight, d_kernel_weight, atol=2e-2, rtol=4e-2)
            torch.testing.assert_close(d_verl_fused_hidden, d_kernel_hidden, atol=2e-2, rtol=4e-2)
            torch.testing.assert_close(d_verl_fused_weight, d_kernel_weight, atol=2e-2, rtol=4e-2)

        # remove first latency (warm-up/compile iteration would skew averages)
        torch_forward_latency = torch_forward_latency[1:]
        torch_backward_latency = torch_backward_latency[1:]
        verl_forward_latency = verl_forward_latency[1:]
        verl_backward_latency = verl_backward_latency[1:]
        verl_fused_forward_latency = verl_fused_forward_latency[1:]
        verl_fused_backward_latency = verl_fused_backward_latency[1:]
        kernel_forward_latency = kernel_forward_latency[1:]
        kernel_backward_latency = kernel_backward_latency[1:]

        print("\n[INFO]: Verified forward & backward correctness.")

        print(
            f"[INFO]: Forward pass: Torch implementation average time: "
            f"{sum(torch_forward_latency) / len(torch_forward_latency):.2f} ms"
        )
        print(
            f"[INFO]: Backward pass: torch implementation average time: "
            f"{sum(torch_backward_latency) / len(torch_backward_latency):.2f} ms"
        )
        print(
            f"[INFO]: Forward pass: VeRL implementation average time: "
            f"{sum(verl_forward_latency) / len(verl_forward_latency):.2f} ms"
        )
        print(
            f"[INFO]: Backward pass: VeRL implementation average time: "
            f"{sum(verl_backward_latency) / len(verl_backward_latency):.2f} ms"
        )
        print(
            f"[INFO]: Forward pass: VeRL Fused Entropy implementation average time: "
            f"{sum(verl_fused_forward_latency) / len(verl_fused_forward_latency):.2f} ms"
        )
        print(
            f"[INFO]: Backward pass: VeRL Fused Entropy implementation average time: "
            f"{sum(verl_fused_backward_latency) / len(verl_fused_backward_latency):.2f} ms"
        )
        print(
            f"[INFO]: Forward pass: Kernel implementation average time: "
            f"{sum(kernel_forward_latency) / len(kernel_forward_latency):.2f} ms"
        )
        print(
            f"[INFO]: Backward pass: kernel implementation average time: "
            f"{sum(kernel_backward_latency) / len(kernel_backward_latency):.2f} ms"
        )

    def check_storage(self, method_name, run_forward):
        """Print peak CUDA memory for one forward and one backward pass of the
        implementation given by ``run_forward``."""
        self.cleanup()
        self.generate_hyper()

        hidden, weight, labels = self.generate_forward_inputs()

        torch.cuda.reset_peak_memory_stats()
        (logprobs, entropy) = run_forward(hidden, weight, labels, self.temperature)
        torch.cuda.synchronize()
        torch_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024
        print(f"[INFO]: {method_name} Forward pass peak memory: {torch_max_memory:.2f} MB")

        g_entropy, g_logprobs = self.generate_backward_inputs()

        torch.cuda.reset_peak_memory_stats()
        (d_torch_hidden, d_torch_weight) = torch.autograd.grad(
            (entropy, logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False
        )
        torch.cuda.synchronize()
        torch_backward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024
        print(f"[INFO]: {method_name} Backward pass peak memory: {torch_backward_max_memory:.2f} MB")

    def check_storage_all(self):
        """Run the peak-memory report for every implementation."""
        self.check_storage("Torch", run_torch_entropy)
        self.check_storage("VeRL", run_verl_original_entropy)
        self.check_storage("VeRL Torch Fused", run_verl_torch_fused_entropy)
        self.check_storage("Kernel", linear_cross_entropy)
349
+
350
+
351
+ if __name__ == "__main__":
352
+ # torch.cuda.memory._record_memory_history()
353
+
354
+ for test_case_idx in range(MAX_TEST_CASES):
355
+ print(f"[INFO] Running test case {test_case_idx}")
356
+ test = TestLinearCrossEntropy(test_case_idx)
357
+
358
+ test.verify_correctness()
359
+ test.check_storage_all()
360
+
361
+ # torch.cuda.memory._dump_snapshot("test_linear_cross_entropy.pkl")
code/RL_model/verl/verl_train/tests/utils/test_mlflow_key_sanitization.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import unittest
16
+ from unittest.mock import patch
17
+
18
+ from verl.utils.tracking import _MlflowLoggingAdapter
19
+
20
+
21
class TestMlflowLoggingAdapter(unittest.TestCase):
    """Unit tests for metric-key sanitization in the MLflow logging adapter."""

    def test_sanitize_key_and_warning(self):
        """Invalid characters and repeated slashes are rewritten, with warnings."""
        adapter = _MlflowLoggingAdapter()
        metrics_in = {
            "valid_key": 1.0,
            "invalid@key!": 2.0,
            "another/valid-key": 3.0,
            "bad key#": 4.0,
            "val-aux//reward/mean_at_1": 5.0,
            "val-core///acc/best_at_5": 6.0,
            "metric////with/many////slashes": 7.0,
        }
        # Intercept both the outgoing mlflow call and the adapter's logger.
        with (
            patch("mlflow.log_metrics") as log_metrics_mock,
            patch.object(adapter, "logger") as logger_mock,
        ):
            adapter.log(metrics_in, step=5)

            sent = log_metrics_mock.call_args[1]["metrics"]
            # '@' becomes '_at_'; '!' and '#' become '_'; originals are gone.
            self.assertIn("invalid_at_key_", sent)
            self.assertIn("bad key_", sent)
            self.assertNotIn("invalid@key!", sent)
            self.assertNotIn("bad key#", sent)
            # Runs of '/' collapse to a single separator.
            self.assertIn("val-aux/reward/mean_at_1", sent)
            self.assertIn("val-core/acc/best_at_5", sent)
            self.assertIn("metric/with/many/slashes", sent)
            self.assertNotIn("val-aux//reward/mean_at_1", sent)
            self.assertNotIn("val-core///acc/best_at_5", sent)
            # Every rewritten key should have produced a warning.
            warnings_seen = [str(call) for call in logger_mock.warning.call_args_list]
            self.assertTrue(any("invalid@key!" in msg and "invalid_at_key_" in msg for msg in warnings_seen))
            self.assertTrue(any("bad key#" in msg and "bad key_" in msg for msg in warnings_seen))
            self.assertTrue(any("val-aux//reward/mean_at_1" in msg for msg in warnings_seen))
            self.assertTrue(any("val-core///acc/best_at_5" in msg for msg in warnings_seen))
            self.assertTrue(any("metric////with/many////slashes" in msg for msg in warnings_seen))
61
+
62
+
63
+ if __name__ == "__main__":
64
+ unittest.main()
code/RL_model/verl/verl_train/tests/utils/test_model_on_cpu.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from types import SimpleNamespace # Or use a mock object library
16
+
17
+ import pytest
18
+
19
+ from verl.utils.model import update_model_config
20
+
21
+
22
+ # Parametrize with different override scenarios
23
@pytest.mark.parametrize(
    "override_kwargs",
    [
        {"param_a": 5, "new_param": "plain_added"},
        {"param_a": 2, "nested_params": {"sub_param_x": "updated_x", "sub_param_z": True}},
    ],
)
def test_update_model_config(override_kwargs):
    """update_model_config applies plain and nested overrides in place,
    leaving untouched attributes intact."""
    # Fresh mock config object per parametrized case.
    cfg = SimpleNamespace(
        param_a=1, nested_params=SimpleNamespace(sub_param_x="original_x", sub_param_y=100), other_param="keep_me"
    )
    update_model_config(cfg, override_kwargs)

    nested = cfg.nested_params
    if "nested_params" not in override_kwargs:
        # Plain override: the nested namespace must be left untouched.
        assert nested.sub_param_x == "original_x", "Nested sub_param_x should be unchanged"
        assert nested.sub_param_y == 100, "Nested sub_param_y should be unchanged"
        assert not hasattr(nested, "sub_param_z"), "Nested sub_param_z should not exist"
    else:
        # Nested override: requested keys change, others survive, new keys appear.
        wanted = override_kwargs["nested_params"]
        assert nested.sub_param_x == wanted["sub_param_x"], "Nested sub_param_x mismatch"
        assert nested.sub_param_y == 100, "Nested sub_param_y should be unchanged"
        assert hasattr(nested, "sub_param_z"), "Expected nested sub_param_z to be added"
        assert nested.sub_param_z == wanted["sub_param_z"], "Value of sub_param_z mismatch"
code/RL_model/verl/verl_train/tests/utils/test_nvtx_profile.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+ from unittest.mock import MagicMock, patch
18
+
19
+ from verl.utils import omega_conf_to_dataclass
20
+ from verl.utils.profiler.config import NsightToolConfig, ProfilerConfig
21
+ from verl.utils.profiler.profile import DistProfiler
22
+
23
+
24
class TestProfilerConfig(unittest.TestCase):
    """Tests ProfilerConfig construction from Hydra/OmegaConf and its immutability."""

    def test_config_init(self):
        """Every role-level profiler section of ppo_trainer.yaml converts to ProfilerConfig.

        NOTE(review): relies on the test process running from the repo root so
        that ``verl/trainer/config`` resolves — confirm in the CI working dir.
        """
        import os

        from hydra import compose, initialize_config_dir

        with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
            cfg = compose(config_name="ppo_trainer")
            for config in [
                cfg.actor_rollout_ref.actor.profiler,
                cfg.actor_rollout_ref.rollout.profiler,
                cfg.actor_rollout_ref.ref.profiler,
                cfg.critic.profiler,
                cfg.reward_model.profiler,
            ]:
                profiler_config = omega_conf_to_dataclass(config)
                # Field-by-field equality between the dataclass and the raw config node.
                self.assertEqual(profiler_config.tool, config.tool)
                self.assertEqual(profiler_config.enable, config.enable)
                self.assertEqual(profiler_config.all_ranks, config.all_ranks)
                self.assertEqual(profiler_config.ranks, config.ranks)
                self.assertEqual(profiler_config.save_path, config.save_path)
                self.assertEqual(profiler_config.ranks, config.ranks)
                assert isinstance(profiler_config, ProfilerConfig)
                # Attribute access on a missing key raises, while dict-style .get
                # mirrors OmegaConf's behavior (None / supplied default).
                with self.assertRaises(AttributeError):
                    _ = profiler_config.non_existing_key
                assert config.get("non_existing_key") == profiler_config.get("non_existing_key")
                assert config.get("non_existing_key", 1) == profiler_config.get("non_existing_key", 1)

    def test_frozen_config(self):
        """Test that modifying frozen keys in ProfilerConfig raises exceptions."""
        from dataclasses import FrozenInstanceError

        from verl.utils.profiler.config import ProfilerConfig

        # Create a new ProfilerConfig instance
        config = ProfilerConfig(all_ranks=False, ranks=[0])

        # Attribute assignment is blocked by the frozen-dataclass machinery.
        with self.assertRaises(FrozenInstanceError):
            config.all_ranks = True

        with self.assertRaises(FrozenInstanceError):
            config.ranks = [1, 2, 3]

        # Item assignment is not supported on the dataclass at all.
        with self.assertRaises(TypeError):
            config["all_ranks"] = True

        with self.assertRaises(TypeError):
            config["ranks"] = [1, 2, 3]
72
+
73
+
74
class TestNsightSystemsProfiler(unittest.TestCase):
    """Test suite for NsightSystemsProfiler functionality.

    Test Plan:
    1. Initialization: Verify profiler state after creation
    2. Basic Profiling: Test start/stop functionality
    3. Discrete Mode: TODO: Test discrete profiling behavior
    4. Annotation: Test the annotate decorator in both normal and discrete modes
    5. Config Validation: Verify proper config initialization from OmegaConf
    """

    def setUp(self):
        # Non-discrete nsys profiling, enabled on all ranks; tests run as rank 0.
        self.config = ProfilerConfig(tool="nsys", enable=True, all_ranks=True)
        self.rank = 0
        self.profiler = DistProfiler(self.rank, self.config, tool_config=NsightToolConfig(discrete=False))

    def test_initialization(self):
        # all_ranks=True: this rank participates; no step is active until start().
        self.assertEqual(self.profiler.check_this_rank(), True)
        self.assertEqual(self.profiler.check_this_step(), False)

    def test_start_stop_profiling(self):
        """start()/stop() toggle the per-step flag and drive torch.cuda.profiler once each."""
        with patch("torch.cuda.profiler.start") as mock_start, patch("torch.cuda.profiler.stop") as mock_stop:
            # Test start
            self.profiler.start()
            self.assertTrue(self.profiler.check_this_step())
            mock_start.assert_called_once()

            # Test stop
            self.profiler.stop()
            self.assertFalse(self.profiler.check_this_step())
            mock_stop.assert_called_once()

    # def test_discrete_profiling(self):
    #     discrete_config = ProfilerConfig(discrete=True, all_ranks=True)
    #     profiler = NsightSystemsProfiler(self.rank, discrete_config)

    #     with patch("torch.cuda.profiler.start") as mock_start, patch("torch.cuda.profiler.stop") as mock_stop:
    #         profiler.start()
    #         self.assertTrue(profiler.this_step)
    #         mock_start.assert_not_called()  # Shouldn't start immediately in discrete mode

    #         profiler.stop()
    #         self.assertFalse(profiler.this_step)
    #         mock_stop.assert_not_called()  # Shouldn't stop immediately in discrete mode

    def test_annotate_decorator(self):
        """In non-discrete mode annotate() opens/closes an NVTX range but never toggles the CUDA profiler."""
        mock_self = MagicMock()
        mock_self.profiler = self.profiler
        mock_self.profiler.start()
        decorator = mock_self.profiler.annotate(message="test")

        @decorator
        def test_func(self, *args, **kwargs):
            return "result"

        with (
            patch("torch.cuda.profiler.start") as mock_start,
            patch("torch.cuda.profiler.stop") as mock_stop,
            patch("verl.utils.profiler.nvtx_profile.mark_start_range") as mock_start_range,
            patch("verl.utils.profiler.nvtx_profile.mark_end_range") as mock_end_range,
        ):
            result = test_func(mock_self)
            self.assertEqual(result, "result")
            mock_start_range.assert_called_once()
            mock_end_range.assert_called_once()
            mock_start.assert_not_called()  # Not discrete mode
            mock_stop.assert_not_called()  # Not discrete mode

    # def test_annotate_discrete_mode(self):
    #     discrete_config = ProfilerConfig(discrete=True, all_ranks=True)
    #     profiler = NsightSystemsProfiler(self.rank, discrete_config)
    #     mock_self = MagicMock()
    #     mock_self.profiler = profiler
    #     mock_self.profiler.this_step = True

    #     @NsightSystemsProfiler.annotate(message="test")
    #     def test_func(self, *args, **kwargs):
    #         return "result"

    #     with (
    #         patch("torch.cuda.profiler.start") as mock_start,
    #         patch("torch.cuda.profiler.stop") as mock_stop,
    #         patch("verl.utils.profiler.nvtx_profile.mark_start_range") as mock_start_range,
    #         patch("verl.utils.profiler.nvtx_profile.mark_end_range") as mock_end_range,
    #     ):
    #         result = test_func(mock_self)
    #         self.assertEqual(result, "result")
    #         mock_start_range.assert_called_once()
    #         mock_end_range.assert_called_once()
    #         mock_start.assert_called_once()  # Should start in discrete mode
    #         mock_stop.assert_called_once()  # Should stop in discrete mode
165
+
166
+
167
+ if __name__ == "__main__":
168
+ unittest.main()
code/RL_model/verl/verl_train/tests/utils/test_rollout_skip_on_cpu.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import shutil
15
+ import tempfile
16
+ from pathlib import Path
17
+ from unittest.mock import MagicMock
18
+
19
+ import pytest
20
+ import torch
21
+
22
+ from verl.utils.rollout_skip import DataProto, RolloutSkip
23
+
24
+ len_prompt = 50
25
+ len_response = 100
26
+
27
+
28
def temp_dir():
    """Yield a freshly created temporary directory; remove it when resumed.

    NOTE(review): not decorated with ``@pytest.fixture`` — callers in this file
    drive it manually with ``next(...)``, in which case the cleanup line after
    the ``yield`` never runs. Confirm whether that leak is intended.
    """
    workdir = Path(tempfile.mkdtemp())
    yield workdir
    # Cleanup runs only when the generator is advanced past the yield.
    shutil.rmtree(workdir)
34
+
35
+
36
def build_generate_fn(gen_bs, n):
    """Return a mock ``generate_sequences`` function producing random DataProto batches.

    Each call returns fresh random prompts (each prompt repeated ``n`` times)
    and random responses, mimicking a rollout inference engine.
    """
    vocab_size = 1024

    def _stream():
        # Infinite stream of random rollout batches.
        while True:
            prompts = torch.randint(vocab_size, size=(gen_bs, len_prompt)).repeat_interleave(n, dim=0)
            responses = torch.randint(vocab_size, size=(gen_bs * n, len_response))
            yield DataProto.from_dict(tensors={"prompt": prompts, "response": responses})

    engine = _stream()

    def fn(batch, **kwargs):
        # The incoming batch is ignored; just emit the next pre-generated rollout.
        return next(engine)

    return fn
53
+
54
+
55
@pytest.fixture(params=[(32, 4), (64, 4), (64, 8)])
def mock_rollout_wg(request):
    """Yield a mocked ``(config, rollout_wg)`` pair for RolloutSkip tests.

    Parametrized over ``(gen_batch_size, rollout n)``. The ``skip_dump_dir`` is
    a per-test temporary directory that is reliably removed on teardown.
    """
    gen_bs, n = request.param
    rollout_wg = MagicMock()
    rollout_wg.generate_sequences = build_generate_fn(gen_bs, n)

    # Create the dump directory directly so teardown removes the *same* path.
    # (The previous implementation called next(temp_dir()) both here and in
    # teardown, which created a second fresh directory at teardown, deleted
    # that one, and leaked the directory actually used by the test.)
    dump_dir = Path(tempfile.mkdtemp())

    config = MagicMock()
    config.actor_rollout_ref.rollout = {
        "n": n,
        "skip_dump_dir": dump_dir,
    }
    config.data = {"gen_batch_size": gen_bs}

    yield config, rollout_wg

    # Cleanup: remove the directory handed to the test, tolerating prior removal.
    shutil.rmtree(dump_dir, ignore_errors=True)
72
+
73
+
74
class TestRolloutSkip:
    """End-to-end tests of RolloutSkip's dump/load patching of generate_sequences."""

    def test_initialization(self, capsys):
        """Test that RolloutSkip initializes correctly"""
        config = MagicMock()
        config.actor_rollout_ref.rollout = {
            "n": 16,
            "skip_dump_dir": "tmp/rollout_dump",
        }
        config.data = {"gen_batch_size": 128}
        mock_rollout_wg = MagicMock()
        skip = RolloutSkip(config, mock_rollout_wg)

        # Config values must be mirrored onto the RolloutSkip instance.
        assert skip.n == 16
        assert skip.gbs == 128
        assert str(skip.dumped_dir) == "tmp/rollout_dump"

        assert skip._rollout_wg == mock_rollout_wg
        # Patching prints a confirmation message to stdout.
        skip.wrap_generate_sequences()
        captured = capsys.readouterr()
        assert "Successfully patched" in captured.out

    def test_generate_without_wrap(self, mock_rollout_wg):
        """Test that generate_sequences works without wrapping"""

        config, rollout_wg = mock_rollout_wg
        _ = RolloutSkip(config, rollout_wg)

        _result = rollout_wg.generate_sequences(MagicMock())
        for _ in range(10):
            result = rollout_wg.generate_sequences(MagicMock())
            assert isinstance(result, DataProto)
            # * make sure the data is different
            # Unwrapped, every call must return freshly generated random data.
            assert torch.abs(_result.batch["prompt"] - result.batch["prompt"]).sum() > 0
            assert torch.abs(_result.batch["response"] - result.batch["response"]).sum() > 0
            _result = result

    def test_dump(self, mock_rollout_wg, capsys):
        """First wrapped call must dump generated data to disk and report it."""
        config, rollout_wg = mock_rollout_wg
        skip = RolloutSkip(config, rollout_wg)
        skip.wrap_generate_sequences()

        result = rollout_wg.generate_sequences(MagicMock())
        # * check if dump is OK
        assert skip.curr_path_dump.exists()
        captured = capsys.readouterr()
        assert "Successfully dump data in" in captured.out
        # * get file size, estimate file size
        # Lower bound: raw token payload (prompt+response tokens * batch * n * itemsize).
        file_size = skip.curr_path_dump.stat().st_size
        est_file_size = (len_prompt + len_response) * skip.gbs * skip.n * result.batch["prompt"].dtype.itemsize
        assert file_size >= est_file_size, "Dumped file size is smaller than expected"

    def test_generate_with_wrap(self, mock_rollout_wg, capsys):
        """Test that wrapped generate_sequences replays the dumped data."""

        config, rollout_wg = mock_rollout_wg
        skip = RolloutSkip(config, rollout_wg)
        skip.wrap_generate_sequences()

        _result = rollout_wg.generate_sequences(MagicMock())

        for _ in range(10):
            result = rollout_wg.generate_sequences(MagicMock())
            assert isinstance(result, DataProto)
            # * make sure the data is different
            # Wrapped, every subsequent call must replay the *same* dumped data.
            assert torch.abs(_result.batch["prompt"] - result.batch["prompt"]).sum() == 0
            assert torch.abs(_result.batch["response"] - result.batch["response"]).sum() == 0
            captured = capsys.readouterr()
            assert "Successfully load pre-generated data from" in captured.out
            _result = result
code/RL_model/verl/verl_train/tests/utils/test_rollout_trace_on_cpu.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import sys
17
+ from unittest.mock import MagicMock, patch
18
+
19
+ import pytest
20
+
21
+ from verl.utils.rollout_trace import RolloutTraceConfig, rollout_trace_attr, rollout_trace_op
22
+
23
+
24
@pytest.fixture(autouse=True)
def reset_rollout_trace_config_singleton():
    """Fixture to reset the RolloutTraceConfig singleton before each test.

    autouse ensures no tracing configuration leaks between tests in this module.
    """
    RolloutTraceConfig.reset()
28
+
29
+
30
@pytest.fixture
def mock_weave_client():
    """Install a fake ``weave`` module in sys.modules and yield its mock client.

    The client's ``create_call`` returns a pre-built mock call object so tests
    can assert on create/finish interactions without the real weave package.
    """
    fake_weave = MagicMock()
    client = MagicMock()
    client.create_call.return_value = MagicMock()
    fake_weave.init.return_value = client

    # The decorator may enter weave's call context internally; mock it too.
    fake_weave.trace.context.call_context.return_value = MagicMock()

    patched_modules = {
        "weave": fake_weave,
        "weave.trace.context": fake_weave.trace.context,
    }
    with patch.dict(sys.modules, patched_modules):
        yield client
44
+
45
+
46
class TracedClass:
    """Async methods decorated with @rollout_trace_op, used as tracing targets.

    The commented-out decorators (@weave.op / @mlflow.trace) document the real
    backend decorators these tests emulate through mocks.
    """

    @rollout_trace_op
    # @weave.op
    # @mlflow.trace
    async def my_method(self, a, b="default"):
        # Traced leaf call: simply echoes its arguments.
        return f"result: {a}, {b}"

    @rollout_trace_op
    # @weave.op
    # @mlflow.trace
    async def middle_method(self, a, b="default"):
        # Traced call that awaits another traced call (exercises nesting).
        await self.my_method("test_a1", b="test_b1")
        return f"result: {a}, {b}"

    @rollout_trace_op
    # @mlflow.trace
    async def my_method_with_exception(self):
        # Traced call that always raises, for exception-propagation tests.
        raise ValueError("Test Exception")

    async def upper_method(self):
        # Untraced entry point fanning out to traced calls.
        await self.my_method("test_a0", b="test_b0")
        await self.middle_method("test_a2", b="test_b2")
        return True
69
+
70
+
71
class UntracedClass:
    """Decorated class used to verify @rollout_trace_op is a no-op without a backend."""

    @rollout_trace_op
    async def my_method(self, x):
        return x * 2
75
+
76
+
77
async def test_rollout_trace_on_untraced_class():
    """The decorator must be a transparent pass-through when no backend is configured."""
    obj = UntracedClass()
    result = await obj.my_method(10)
    assert result == 20
81
+
82
+
83
async def test_rollout_trace_with_tracer(mock_weave_client):
    """With the weave backend active, a traced call is created and finished once."""
    RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="weave")
    obj = TracedClass()
    assert RolloutTraceConfig.get_client() is mock_weave_client

    result = await obj.my_method("test_a", b="test_b")
    assert result == "result: test_a, test_b"

    # Exactly one call recorded, named after class+method, with bound arguments.
    mock_weave_client.create_call.assert_called_once()
    recorded = mock_weave_client.create_call.call_args.kwargs
    assert recorded["op"] == "TracedClass.my_method"
    assert recorded["inputs"] == {"a": "test_a", "b": "test_b"}

    # The same call object is finished with the method's return value.
    pending_call = mock_weave_client.create_call.return_value
    mock_weave_client.finish_call.assert_called_once_with(pending_call, output=result)
100
+
101
+
102
async def test_rollout_trace_with_exception(mock_weave_client):
    """finish_call must receive the raised exception instead of an output."""
    RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="weave")
    obj = TracedClass()

    with pytest.raises(ValueError, match="Test Exception"):
        await obj.my_method_with_exception()

    mock_weave_client.create_call.assert_called_once()
    mock_weave_client.finish_call.assert_called_once()

    # Signature observed: finish_call(call, exception=ValueError(...)).
    args, kwargs = mock_weave_client.finish_call.call_args
    assert args[0] == mock_weave_client.create_call.return_value
    assert "exception" in kwargs
    assert isinstance(kwargs["exception"], ValueError)
119
+
120
+
121
async def test_rollout_trace_with_dummy_backend(mock_weave_client):
    """The 'dummy' backend must never touch the weave client."""
    RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="dummy")
    obj = TracedClass()

    await obj.my_method("test_a")

    mock_weave_client.create_call.assert_not_called()
129
+
130
+
131
async def test_trace_disabled_with_trace_false(mock_weave_client):
    """rollout_trace_attr(trace=False) must suppress tracing inside its scope."""
    RolloutTraceConfig.init(
        project_name="my-project",
        experiment_name="my-experiment",
        backend="weave",
    )
    obj = TracedClass()

    assert RolloutTraceConfig.get_backend() == "weave"

    # trace=False: the traced method runs normally but records nothing.
    with rollout_trace_attr(step=1, sample_index=0, rollout_n=0, trace=False):
        outcome = await obj.my_method("test_a", b="test_b")
        assert outcome == "result: test_a, test_b"

    mock_weave_client.create_call.assert_not_called()

    # Default trace=True: tracing works again after leaving the disabled scope.
    with rollout_trace_attr(step=1, sample_index=0, rollout_n=0):
        outcome = await obj.my_method("test_a", b="test_b")
        assert outcome == "result: test_a, test_b"

    assert mock_weave_client.create_call.call_count == 1
155
+
156
+
157
async def test_trace_false_disables_nested_trace_ops(mock_weave_client):
    """trace=False must also disable every nested @rollout_trace_op call."""
    RolloutTraceConfig.init(
        project_name="my-project",
        experiment_name="my-experiment",
        backend="weave",
    )
    obj = TracedClass()

    with rollout_trace_attr(step=1, sample_index=0, rollout_n=0, trace=False):
        # upper_method fans out into my_method and middle_method, all of which
        # carry @rollout_trace_op — none may be recorded.
        assert await obj.upper_method() is True

    mock_weave_client.create_call.assert_not_called()

    # Tracing is back on in a fresh default scope.
    with rollout_trace_attr(step=1, sample_index=0, rollout_n=0):
        outcome = await obj.my_method("test_a", b="test_b")
        assert outcome == "result: test_a, test_b"

    assert mock_weave_client.create_call.call_count == 1
180
+
181
+
182
async def test_trace_enabled_restored_after_exception(mock_weave_client):
    """The trace-enabled state must be restored even if trace=False scope raises."""
    RolloutTraceConfig.init(
        project_name="my-project",
        experiment_name="my-experiment",
        backend="weave",
    )
    obj = TracedClass()

    assert RolloutTraceConfig.get_backend() == "weave"

    # Raise inside a trace=False scope; the exception is expected and swallowed.
    with pytest.raises(RuntimeError):
        with rollout_trace_attr(step=1, sample_index=0, rollout_n=0, trace=False):
            raise RuntimeError("Test exception with trace disabled")

    # A subsequent default scope must trace normally.
    with rollout_trace_attr(step=1, sample_index=0, rollout_n=0):
        outcome = await obj.my_method("test_a", b="test_b")
        assert outcome == "result: test_a, test_b"

    assert mock_weave_client.create_call.call_count == 1
205
+
206
+
207
@pytest.mark.skipif(
    os.environ.get("RUN_WEAVE_INTEGRATION_TESTS", "false").lower() != "true",
    reason="Skipping weave integration test. Set RUN_WEAVE_INTEGRATION_TESTS=true to run.",
)
async def test_rollout_trace_with_real_weave_backend():
    """Integration test with a real weave backend.

    Opt-in via RUN_WEAVE_INTEGRATION_TESTS=true; assumes the weave project is
    already configured in the environment.
    """
    RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="weave")

    instance = TracedClass()

    # Nested traced calls under explicit trace attributes.
    with rollout_trace_attr(step=1, sample_index=2, rollout_n=3):
        await instance.upper_method()

    # A traced call that raises must still propagate its exception.
    with pytest.raises(ValueError, match="Test Exception"):
        await instance.my_method_with_exception()

    print("\nWeave integration test ran successfully. Check your weave project for the trace.")
226
+
227
+
228
@pytest.mark.skipif(
    os.environ.get("RUN_MLFLOW_INTEGRATION_TESTS", "false").lower() != "true",
    reason="Skipping mlflow integration test. Set RUN_MLFLOW_INTEGRATION_TESTS=true to run.",
)
async def test_rollout_trace_with_real_mlflow_backend():
    """Integration test with a real mlflow backend.

    Opt-in via RUN_MLFLOW_INTEGRATION_TESTS=true; assumes the mlflow tracking
    environment is already configured.
    """
    RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="mlflow")

    instance = TracedClass()

    with rollout_trace_attr(step=1, sample_index=2, rollout_n=3, name="agent_run"):
        assert await instance.upper_method()

    # TODO(review): exception path is disabled for mlflow; re-enable once the
    # backend surfaces exceptions the same way the weave path does.
    # with pytest.raises(ValueError, match="Test Exception"):
    #     await instance.my_method_with_exception()

    # Fix: this message previously claimed to be the *weave* integration test.
    print("\nMlflow integration test ran successfully. Check your mlflow project for the trace.")
code/RL_model/verl/verl_train/tests/utils/test_seqlen_balancing.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import torch.distributed as dist
17
+ import torch.multiprocessing as mp
18
+
19
+ from verl import DataProto
20
+ from verl.utils.device import get_device_name, get_nccl_backend, get_torch_device
21
+ from verl.utils.model import create_random_mask
22
+ from verl.utils.seqlen_balancing import (
23
+ ceildiv,
24
+ get_reverse_idx,
25
+ prepare_dynamic_batch,
26
+ rearrange_micro_batches,
27
+ restore_dynamic_batch,
28
+ )
29
+
30
+
31
def test_seqlen_balancing():
    """Micro-batch rearrangement must be exactly invertible via get_reverse_idx."""
    tokens = torch.randint(low=0, high=10, size=(20, 100))
    mask = create_random_mask(
        input_ids=tokens, max_ratio_of_left_padding=0.1, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.5
    )
    proto = DataProto.from_single_dict({"input_ids": tokens, "attention_mask": mask})

    micro_batches, index_lists = rearrange_micro_batches(proto.batch, max_token_len=300)

    # Flatten the per-micro-batch index lists, invert the permutation, and
    # confirm the concatenated micro-batches map back onto the original batch.
    flat_indices = [i for chunk in index_lists for i in chunk]
    inverse = torch.tensor(get_reverse_idx(flat_indices))
    restored = torch.cat(micro_batches)[inverse]
    torch.testing.assert_close(restored, proto.batch)
48
+
49
+
50
def test_dynamic_batch():
    """prepare_dynamic_batch must be undone exactly by restore_dynamic_batch."""
    tokens = torch.randint(low=0, high=10, size=(20, 100))
    mask = create_random_mask(
        input_ids=tokens, max_ratio_of_left_padding=0.1, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.5
    )
    proto = DataProto.from_single_dict({"input_ids": tokens, "attention_mask": mask})

    micro_batches, index_lists = prepare_dynamic_batch(proto, max_token_len=300)

    # Re-concatenate the micro-batches, then undo the permutation.
    merged = torch.cat([mb.batch["input_ids"] for mb in micro_batches], dim=0)
    restored = restore_dynamic_batch(merged, index_lists)
    torch.testing.assert_close(restored, proto.batch["input_ids"])
62
+
63
+
64
def _worker(rank, world_size, init_method, max_token_len, use_same_dp, min_mb):
    """Per-rank body for the distributed rearrange_micro_batches tests.

    Spawned via torch.multiprocessing; each rank builds a different-sized
    random batch so that the natural micro-batch counts disagree across DP
    ranks — exactly what same_micro_num_in_dp / min_num_micro_batch reconcile.
    """
    # 1) init process group & CUDA
    get_torch_device().set_device(rank)
    dist.init_process_group(
        backend=get_nccl_backend(),
        init_method=init_method,
        world_size=world_size,
        rank=rank,
    )

    # 2) build a small random batch (each rank different length to force mismatch)
    torch.manual_seed(42 + rank)
    input_ids = torch.randint(0, 10, (20 + rank * 5, 100), device=f"{get_device_name()}:{rank}")
    attention_mask = create_random_mask(
        input_ids=input_ids,
        max_ratio_of_left_padding=0.1,
        max_ratio_of_valid_token=0.9,
        min_ratio_of_valid_token=0.5,
    )
    dp = {"input_ids": input_ids, "attention_mask": attention_mask}
    proto = DataProto.from_single_dict(dp)
    batch = proto.batch

    # 3) call rearrange_micro_batches with one of the two params under test
    micros, idx_lst = rearrange_micro_batches(
        batch,
        max_token_len=max_token_len,
        dp_group=dist.group.WORLD,
        same_micro_num_in_dp=use_same_dp,
        min_num_micro_batch=min_mb,
    )

    # 4) check the enforced counts
    # "Natural" local count: one micro-batch per max_token_len worth of valid
    # tokens, capped at the number of samples.
    seq_len_effective: torch.Tensor = batch["attention_mask"].sum(dim=1)
    total_seqlen = seq_len_effective.sum().item()
    local = min(len(seq_len_effective), ceildiv(total_seqlen, max_token_len))

    if min_mb is not None:
        # min_num_micro_batch only raises the count, never lowers it.
        expected = max(local, min_mb)
        assert len(micros) == expected
    if use_same_dp:
        # gather all local_counts
        counts = [torch.zeros(1, device=f"{get_device_name()}:{rank}") for _ in range(world_size)]
        counts[rank].fill_(local)
        dist.all_gather(counts, counts[rank])
        # Every rank must settle on the max natural count across the DP group.
        expected = max(int(c.item()) for c in counts)
        assert len(micros) == expected
    else:
        # if neither, we get the local natural count
        # NOTE(review): this else also runs when min_mb is set (it pairs with
        # the `if use_same_dp`), so it only passes while local >= min_mb for
        # the batch sizes used here — confirm this is intended.
        assert len(micros) == local

    # 5) reconstruction sanity: concat→reverse_idx→orig
    flat = torch.cat(micros, dim=0)
    idx = []
    for sub in idx_lst:
        idx.extend(sub)
    inv = get_reverse_idx(idx)
    inv = torch.tensor(inv, device=flat.device)
    reconstructed = flat[inv]
    torch.testing.assert_close(reconstructed, batch)

    dist.destroy_process_group()
126
+
127
+
128
def test_dataproto_split_uneven():
    """DataProto.split must handle remainders, oversized chunks and non-tensor data."""
    ids = torch.randint(low=0, high=10, size=(10, 5))
    mask = torch.ones(10, 5)
    proto = DataProto.from_single_dict({"input_ids": ids, "attention_mask": mask})

    # 10 items split by 3 -> chunk sizes [3, 3, 3, 1].
    chunks = proto.split(3)
    assert [len(c) for c in chunks] == [3, 3, 3, 1]

    rebuilt = DataProto.concat(chunks)
    torch.testing.assert_close(rebuilt.batch["input_ids"], proto.batch["input_ids"])
    torch.testing.assert_close(rebuilt.batch["attention_mask"], proto.batch["attention_mask"])

    # Split size equal to the length -> a single full chunk.
    chunks = proto.split(10)
    assert [len(c) for c in chunks] == [10]

    # Split size larger than the length -> still one chunk with everything.
    chunks = proto.split(15)
    assert [len(c) for c in chunks] == [10]

    # Non-tensor (object ndarray) data must survive split + concat.
    import numpy as np

    labels = np.array([f"label_{i}" for i in range(10)], dtype=object)
    proto_obj = DataProto.from_single_dict(
        {"input_ids": ids, "attention_mask": mask, "labels": labels}
    )

    chunks = proto_obj.split(3)
    assert [len(c) for c in chunks] == [3, 3, 3, 1]

    rebuilt = DataProto.concat(chunks)
    np.testing.assert_array_equal(
        rebuilt.non_tensor_batch["labels"], proto_obj.non_tensor_batch["labels"]
    )
180
+
181
+
182
def test_seqlen_balancing_distributed_params(tmp_path):
    """Spawn a 2-rank group and exercise min_num_micro_batch / same_micro_num_in_dp."""
    world_size = 2
    init_file = tmp_path / "dist_init"
    init_file.write_text("")  # file:// rendezvous requires the file to exist
    init_method = f"file://{init_file}"

    # One spawn per parameter under test:
    #   (same_micro_num_in_dp, min_num_micro_batch)
    scenarios = [
        (False, 4),    # min_num_micro_batch only
        (True, None),  # same_micro_num_in_dp only
    ]
    for use_same_dp, min_mb in scenarios:
        mp.spawn(
            _worker,
            args=(world_size, init_method, 300, use_same_dp, min_mb),
            nprocs=world_size,
            join=True,
        )
203
+
204
+
205
def test_group_balanced_partitions():
    """Group-level balancing must never split samples sharing a uid."""
    from verl.utils.seqlen_balancing import get_group_balanced_partitions

    # Four groups of four samples each, with distinct per-group sequence lengths:
    # uid 0 -> 100s, uid 1 -> 200s, uid 2 -> 150s, uid 3 -> 50s.
    seqlens = [100] * 4 + [200] * 4 + [150] * 4 + [50] * 4
    uids = [0] * 4 + [1] * 4 + [2] * 4 + [3] * 4

    partitions = get_group_balanced_partitions(seqlens, uids, k_partitions=2)

    assert len(partitions) == 2

    # Every index must be covered by the union of the partitions.
    covered = {i for part in partitions for i in part}
    assert covered == set(range(16))

    # All samples of any uid present in a partition must live in that partition.
    for part in partitions:
        members = set(part)
        for uid in {uids[i] for i in part}:
            group = [i for i, u in enumerate(uids) if u == uid]
            assert all(i in members for i in group), f"uid {uid} samples split across partitions"
235
+
236
+
237
def test_group_balanced_partitions_single_sample_groups():
    """Group balancing degenerates gracefully when each sample is its own group (n=1)."""
    from verl.utils.seqlen_balancing import get_group_balanced_partitions

    lengths = [100, 200, 150, 50, 300, 250]
    uids = list(range(6))

    partitions = get_group_balanced_partitions(lengths, uids, k_partitions=2)

    assert len(partitions) == 2
    covered = {i for part in partitions for i in part}
    assert covered == set(range(6))
252
+
253
+
254
def test_group_balanced_partitions_equal_size():
    """Group balancing across 4 partitions (simulating world_size=4)."""
    from verl.utils.seqlen_balancing import get_group_balanced_partitions

    # Eight 2-sample groups; expand per-group lengths/uids to per-sample lists.
    group_lengths = [100, 200, 150, 50, 300, 250, 180, 120]
    seqlens = [length for length in group_lengths for _ in range(2)]
    uids = [gid for gid in range(8) for _ in range(2)]

    partitions = get_group_balanced_partitions(seqlens, uids, k_partitions=4)

    assert len(partitions) == 4

    # Every index must be covered by the union of the partitions.
    covered = {i for part in partitions for i in part}
    assert covered == set(range(16))

    # No uid's samples may straddle two partitions.
    for part in partitions:
        members = set(part)
        for uid in {uids[i] for i in part}:
            group = [i for i, u in enumerate(uids) if u == uid]
            assert all(i in members for i in group)
code/RL_model/verl/verl_train/tests/utils/test_shared_memory.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
import multiprocessing
import sys
import unittest
from multiprocessing import shared_memory

import torch

from verl.workers.rollout.vllm_rollout.utils import create_shared_memory, rebuild_shared_memory
22
+
23
+
24
class TestSharedMemory(unittest.TestCase):
    """Test cases for shared memory utility functions.

    These tests exercise ``create_shared_memory`` / ``rebuild_shared_memory``
    in-process. Reference lifetime matters throughout: tensors produced by
    ``torch.frombuffer`` hold views into ``shm.buf``, so they are explicitly
    deleted before the shared-memory handles go away.
    """

    def setUp(self):
        """Set up test fixtures before each test method."""
        # Use short unique names to avoid POSIX shared memory name length limits
        import uuid

        short_id = uuid.uuid4().hex[:8]
        self.test_name = f"shm_{short_id}"

    def tearDown(self):
        """Clean up shared memory after each test method."""
        # Note: We're relying on the OS to clean up shared memory
        # as we properly delete all references in the tests
        # NOTE(review): segments are never unlink()ed here, so on Linux the
        # backing /dev/shm files may persist until reboot — confirm this is
        # acceptable for CI environments.
        pass

    def test_create_shared_memory_new(self):
        """Test creating new shared memory with unique name."""
        size = 1024

        shm = create_shared_memory(size, self.test_name)

        # Verify shared memory object is created correctly
        self.assertIsNotNone(shm)
        # Note: shared memory may have system-dependent size rounding
        self.assertGreaterEqual(shm.size, size)
        self.assertEqual(shm.name, self.test_name)

        # Clean up - delete tensor references first
        del shm

    def test_create_shared_memory_attach_existing(self):
        """Test that create_shared_memory attaches to existing shared memory when FileExistsError occurs."""
        size = 2048

        # First, create shared memory
        shm1 = create_shared_memory(size, self.test_name)
        self.assertGreaterEqual(shm1.size, size)

        # Second call should attach to existing memory
        shm2 = create_shared_memory(size, self.test_name)

        # Verify we attached to the same shared memory
        self.assertIsNotNone(shm2)
        self.assertGreaterEqual(shm2.size, size)
        self.assertEqual(shm2.name, self.test_name)

        # Both should reference the same shared memory
        self.assertEqual(shm1.name, shm2.name)

        # Clean up
        del shm1, shm2

    def test_rebuild_shared_memory_default_dtype(self):
        """Test rebuilding tensor from shared memory with default dtype (uint8)."""
        size = 1024

        # Create and write to shared memory
        shm = create_shared_memory(size, self.test_name)
        test_data = torch.arange(size, dtype=torch.uint8)
        shm.buf[:size] = test_data.numpy().tobytes()

        # Rebuild tensor from shared memory
        tensor, _ = rebuild_shared_memory(self.test_name, size)

        # Verify tensor properties
        self.assertEqual(tensor.dtype, torch.uint8)
        self.assertEqual(len(tensor), size)

        # Verify data integrity
        # NOTE(review): `reconstructed` is read back from the same shm buffer
        # the rebuilt tensor maps, so this primarily checks that both views
        # agree byte-for-byte with what was written above.
        reconstructed = torch.frombuffer(shm.buf[:size], dtype=torch.uint8)
        self.assertTrue(torch.equal(tensor, reconstructed))

        # Clean up - delete references before closing
        del tensor, reconstructed

    def test_rebuild_shared_memory_custom_dtype(self):
        """Test rebuilding tensor from shared memory with custom dtype."""
        size = 256  # 256 bytes = 64 float32 values

        # Create and write to shared memory
        shm = create_shared_memory(size, self.test_name)
        test_data = torch.arange(64, dtype=torch.float32)
        shm.buf[:size] = test_data.numpy().tobytes()

        # Rebuild tensor with custom dtype
        tensor, _ = rebuild_shared_memory(self.test_name, size, dtype=torch.float32)

        # Verify tensor properties: byte count is reinterpreted per-dtype,
        # so 256 bytes yields 64 float32 elements.
        self.assertEqual(tensor.dtype, torch.float32)
        self.assertEqual(len(tensor), 64)

        # Verify data integrity
        reconstructed = torch.frombuffer(shm.buf[:size], dtype=torch.float32)
        self.assertTrue(torch.equal(tensor, reconstructed))

        # Clean up - delete references before closing
        del tensor, reconstructed

    def test_shared_memory_data_integrity(self):
        """Test that data remains intact between create and rebuild operations."""
        size = 512

        # Create test data with various patterns
        test_data = torch.randint(0, 256, (size,), dtype=torch.uint8)

        # Create shared memory and write data
        shm = create_shared_memory(size, self.test_name)
        shm.buf[:size] = test_data.numpy().tobytes()

        # Rebuild tensor
        tensor, _ = rebuild_shared_memory(self.test_name, size)

        # Verify data integrity against the original random payload
        reconstructed = torch.frombuffer(shm.buf[:size], dtype=torch.uint8)
        self.assertTrue(torch.equal(test_data, reconstructed))

        # Clean up - delete references before closing
        del tensor, reconstructed

    def test_shared_memory_different_dtypes(self):
        """Test shared memory operations with different tensor dtypes."""
        # (dtype, byte size, expected element count) — element count is
        # size // itemsize for each dtype.
        test_cases = [
            (torch.float32, 256, 64),  # 256 bytes / 4 bytes = 64 values
            (torch.float64, 256, 32),  # 256 bytes / 8 bytes = 32 values
            (torch.int32, 256, 64),  # 256 bytes / 4 bytes = 64 values
            (torch.int64, 256, 32),  # 256 bytes / 8 bytes = 32 values
            (torch.uint8, 256, 256),  # 256 bytes / 1 byte = 256 values
        ]

        for dtype, size, expected_len in test_cases:
            # Create test data
            test_data = torch.arange(expected_len, dtype=dtype)

            # Create shared memory and write data
            # NOTE(review): the same self.test_name is reused on every loop
            # iteration without closing the previous handle — relies on
            # create_shared_memory attaching to the existing segment.
            shm = create_shared_memory(size, self.test_name)
            shm.buf[:size] = test_data.numpy().tobytes()

            # Rebuild tensor
            tensor, _ = rebuild_shared_memory(self.test_name, size, dtype=dtype)

            # Verify properties and data
            self.assertEqual(tensor.dtype, dtype)
            self.assertEqual(len(tensor), expected_len)

            reconstructed = torch.frombuffer(shm.buf[:size], dtype=dtype)
            self.assertTrue(torch.equal(test_data, reconstructed))

            # Clean up - delete references before closing
            del tensor, reconstructed

    def test_shared_memory_multiple_operations(self):
        """Test multiple create/rebuild operations with the same name."""
        size = 512

        # First iteration
        test_data1 = torch.arange(size, dtype=torch.uint8)
        shm1 = create_shared_memory(size, self.test_name)
        shm1.buf[:size] = test_data1.numpy().tobytes()
        tensor1, _ = rebuild_shared_memory(self.test_name, size)
        reconstructed1 = torch.frombuffer(shm1.buf[:size], dtype=torch.uint8)
        self.assertTrue(torch.equal(test_data1, reconstructed1))
        del tensor1, reconstructed1, shm1

        # Second iteration with different data
        # (uint8 arithmetic wraps modulo 256 for values >= 128 * 2)
        test_data2 = torch.arange(size, dtype=torch.uint8) * 2
        shm2 = create_shared_memory(size, self.test_name)
        shm2.buf[:size] = test_data2.numpy().tobytes()
        tensor2, _ = rebuild_shared_memory(self.test_name, size)
        reconstructed2 = torch.frombuffer(shm2.buf[:size], dtype=torch.uint8)
        self.assertTrue(torch.equal(test_data2, reconstructed2))
        del tensor2, reconstructed2, shm2
197
+
198
+
199
+ # Module-level function for cross-process testing
200
+ def child_process_function(name, size, test_data_bytes):
201
+ """Child process function to rebuild and verify tensor."""
202
+ shm = None
203
+ tensor = None
204
+ test_data = None
205
+ try:
206
+ # Convert bytes back to tensor
207
+ test_data = torch.frombuffer(test_data_bytes, dtype=torch.uint8)
208
+
209
+ # Attach to shared memory
210
+ shm = shared_memory.SharedMemory(name=name)
211
+
212
+ # Rebuild tensor from shared memory
213
+ tensor = torch.frombuffer(shm.buf[:size], dtype=torch.uint8)
214
+
215
+ # Verify data integrity
216
+ assert torch.equal(test_data, tensor), "Data mismatch in child process"
217
+ return True
218
+ except Exception as e:
219
+ print(f"Error in child process: {e}")
220
+ return False
221
+ finally:
222
+ # Clean up shared memory in child process
223
+ # Delete all references first
224
+ del tensor, test_data
225
+ if shm is not None:
226
+ shm.close()
227
+ # Note: Don't unlink in child process, parent will clean up
228
+
229
+
230
class TestSharedMemoryIntegration(unittest.TestCase):
    """Integration tests for shared memory operations across process boundaries."""

    def test_cross_process_shared_memory(self):
        """Test shared memory can be created in one process and accessed in another."""
        size = 1024
        test_data = torch.arange(size, dtype=torch.uint8)

        # Create shared memory in parent process
        # NOTE(review): fixed segment name — a stale segment from a crashed
        # earlier run could collide; confirm create_shared_memory attaches in
        # that case.
        shm = create_shared_memory(size, "test_cross_proc")
        shm.buf[:size] = test_data.numpy().tobytes()

        # Convert tensor to bytes for passing to child process
        test_data_bytes = test_data.numpy().tobytes()

        # Start child process
        process = multiprocessing.Process(
            target=child_process_function, args=("test_cross_proc", size, test_data_bytes)
        )
        process.start()
        process.join(timeout=5)

        # Verify child process completed successfully.
        # A timeout leaves exitcode as None, which also fails this assertion.
        self.assertEqual(process.exitcode, 0, "Child process failed")

        # Clean up
        del shm
257
+
258
+
259
if __name__ == "__main__":
    # Allow running this test module directly: `python test_shared_memory.py`.
    unittest.main()
code/RL_model/verl/verl_train/tests/utils/test_special_linear_cross_entropy_tp.py ADDED
@@ -0,0 +1,514 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+
18
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
19
+ #
20
+ # Licensed under the Apache License, Version 2.0 (the "License");
21
+ # you may not use this file except in compliance with the License.
22
+ # You may obtain a copy of the License at
23
+ #
24
+ # http://www.apache.org/licenses/LICENSE-2.0
25
+ #
26
+ # Unless required by applicable law or agreed to in writing, software
27
+ # distributed under the License is distributed on an "AS IS" BASIS,
28
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29
+ # See the License for the specific language governing permissions and
30
+ # limitations under the License.
31
+
32
+ import os
33
+
34
+ import torch
35
+ import torch.distributed as dist
36
+
37
+ try:
38
+ from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy
39
+ except ImportError:
40
+ # FIXME: remove these manually included paths
41
+ import sys
42
+
43
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../")))
44
+ finally:
45
+ from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy
46
+
47
+ import verl.utils.torch_functional as verl_F
48
+
49
+ compute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True)
50
+
51
+ MAX_TEST_CASES = os.environ.get("MAX_TEST_CASES", 5)
52
+ VERIFY_TORCH_SELF = os.environ.get("VERIFY_TORCH_SELF", False)
53
+ LOW_MEMORY = os.environ.get("LOW_MEMORY", False)
54
+ LOW_MEMORY_DIV_FACTOR = os.environ.get("LOW_MEMORY_DIV_FACTOR", 16)
55
+
56
+
57
+ def run_torch_entropy(
58
+ hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, temperature: float, reduction="none"
59
+ ) -> list[torch.Tensor]:
60
+ # [num_tokens, vocab_size]
61
+ if len(hidden.shape) > 2:
62
+ hidden = hidden.view(-1, hidden.shape[-1]) # [num_tokens, hidden_size]
63
+ if len(labels.shape) > 1:
64
+ labels = labels.view(-1)
65
+ logits = torch.matmul(
66
+ hidden.to(torch.float32),
67
+ weight.to(torch.float32) if weight.size(0) == hidden.size(1) else weight.T.to(torch.float32),
68
+ )
69
+ logits /= temperature
70
+ pd = torch.nn.functional.softmax(logits, dim=-1) # [num_tokens, vocab_size]
71
+ entropy_a = torch.logsumexp(logits, dim=-1) # [num_tokens]
72
+ entropy_b = torch.sum(pd * logits, dim=-1) # [num_tokens]
73
+ entropy = entropy_a - entropy_b
74
+ logprobs = torch.nn.functional.cross_entropy(logits, labels, reduction=reduction) # [num_tokens]
75
+ logprobs = torch.neg(logprobs)
76
+ return logprobs, entropy
77
+
78
+
79
+ class TorchEntropyTP(torch.autograd.Function):
80
+ """
81
+ it is used for testing the correctness of the kernel
82
+ it is not efficient and is not recommended to use in practice
83
+ """
84
+
85
+ @staticmethod
86
+ def forward(
87
+ ctx,
88
+ hidden: torch.Tensor,
89
+ weight: torch.Tensor,
90
+ labels: torch.Tensor,
91
+ temperature: float,
92
+ dist_process_group: torch.distributed.ProcessGroup,
93
+ ):
94
+ # weight has shape [vocab_size, hidden_size], hidden has shape [num_tokens, hidden_size]
95
+ ctx.original_hidden_shape = hidden.shape
96
+ if len(hidden.shape) > 2:
97
+ hidden = hidden.view(-1, hidden.shape[-1]) # [num_tokens, hidden_size]
98
+ if len(labels.shape) > 1:
99
+ labels = labels.view(-1)
100
+
101
+ logits = torch.matmul(hidden.to(torch.float32), weight.to(torch.float32).T) # [num_tokens, vocab_size]
102
+ logits /= temperature
103
+ whole_logits = torch.empty(
104
+ (logits.shape[0], logits.shape[1] * dist.get_world_size(dist_process_group)),
105
+ dtype=logits.dtype,
106
+ device=logits.device,
107
+ )
108
+ whole_logits_ref = [
109
+ whole_logits[:, i * logits.shape[1] : (i + 1) * logits.shape[1]]
110
+ for i in range(dist.get_world_size(dist_process_group))
111
+ ]
112
+ dist.all_gather(whole_logits_ref, logits, group=dist_process_group)
113
+
114
+ pd = torch.nn.functional.softmax(whole_logits, dim=-1)
115
+ entropy_a = torch.logsumexp(whole_logits, dim=-1) # [num_tokens]
116
+ entropy_b = torch.sum(pd * whole_logits, dim=-1) # [num_tokens]
117
+ entropy = entropy_a - entropy_b
118
+
119
+ logprobs = torch.nn.functional.cross_entropy(whole_logits, labels, reduction="none")
120
+ logprobs = torch.neg(logprobs)
121
+
122
+ ctx.save_for_backward(hidden, weight, labels, whole_logits, entropy_b)
123
+ ctx.dist_process_group = dist_process_group
124
+ ctx.temperature = temperature
125
+ return logprobs, entropy
126
+
127
+ @staticmethod
128
+ def backward(ctx, g_logprobs: torch.Tensor, g_entropy: torch.Tensor):
129
+ hidden, weight, labels, whole_logits, entropy_b = ctx.saved_tensors
130
+ dist_process_group = ctx.dist_process_group
131
+ temperature = ctx.temperature
132
+ batch_size, hidden_size = hidden.shape
133
+ vocab_size, hidden_size = weight.shape
134
+ rank = dist.get_rank(dist_process_group)
135
+
136
+ # Compute softmax probabilities
137
+ maximum, _ = torch.max(whole_logits, dim=-1, keepdim=True)
138
+ exp_logits = torch.exp(whole_logits - maximum)
139
+ accumulate = exp_logits.sum(dim=-1, keepdim=True)
140
+ pd = exp_logits / accumulate
141
+
142
+ # Gradient for entropy
143
+ # entropy = entropy_a - entropy_b
144
+ # entropy_a = log(sum(exp(logits)))
145
+ # entropy_b = sum(pd * logits)
146
+ # d_entropy_a/d_logits = pd
147
+ # d_entropy_b/d_logits = pd * (logits - b.unsqueeze(1) + 1)
148
+ # d_entropy/d_logits = d_entropy_a - d_entropy_b
149
+ # d_entropy/d_logits = pd - pd * (logits - b.unsqueeze(1) + 1)
150
+ # d_entropy/d_logits = -pd * (logits - b.unsqueeze(1))
151
+ d_logits_entropy = g_entropy.unsqueeze(1) * (-pd * (whole_logits - entropy_b.unsqueeze(1)))
152
+
153
+ # Gradient for logprobs
154
+ # logprobs = -cross_entropy = -log(pd[labels])
155
+ # d_logprobs/d_logits = (pd - one_hot(labels))
156
+ one_hot = torch.zeros_like(whole_logits)
157
+ one_hot.scatter_(1, labels.unsqueeze(1), 1)
158
+ g_logprobs = torch.neg(g_logprobs)
159
+ d_logits_logprobs = g_logprobs.unsqueeze(1) * (pd - one_hot)
160
+ # NOTE: This will lead to wrong result
161
+ # d_logits_logprobs = g_logprobs.unsqueeze(1) * (pd - 1) * one_hot
162
+
163
+ # Combine gradients
164
+ d_logits = d_logits_entropy + d_logits_logprobs
165
+ d_logits /= temperature
166
+
167
+ # Get local slice of gradients
168
+ local_d_logits = d_logits[:, rank * vocab_size : (rank + 1) * vocab_size]
169
+
170
+ # Compute gradients for hidden and weight
171
+ d_hidden = torch.matmul(local_d_logits, weight.to(torch.float32))
172
+ d_weight = torch.matmul(local_d_logits.T, hidden.to(torch.float32))
173
+ d_hidden = d_hidden.view(ctx.original_hidden_shape)
174
+
175
+ return d_hidden, d_weight, None, None, None
176
+
177
+
178
+ run_torch_entropy_tp = TorchEntropyTP.apply
179
+
180
+
181
+ class TestLinearCrossEntropy_TensorParallel:
182
+ def __init__(self):
183
+ dist.init_process_group(backend="nccl")
184
+ self.group = dist.group.WORLD
185
+
186
+ self.local_rank = dist.get_rank(self.group)
187
+ self.world_size = dist.get_world_size(self.group)
188
+ device = torch.device(f"cuda:{self.local_rank}")
189
+ torch.cuda.set_device(device)
190
+ print(f"[INFO]: Local rank: {self.local_rank}, World size: {self.world_size}")
191
+
192
+ def initialize(self, test_case_idx: int, temperature: float = 1.5):
193
+ self.test_case_idx = test_case_idx
194
+ self.temperature = temperature
195
+
196
+ def shutdown(self):
197
+ dist.destroy_process_group()
198
+
199
+ def cleanup(self):
200
+ torch.cuda.empty_cache()
201
+ torch.cuda.reset_peak_memory_stats()
202
+ import gc
203
+
204
+ gc.collect()
205
+ torch.cuda.synchronize()
206
+
207
+ def generate_hyper(self):
208
+ global LOW_MEMORY, LOW_MEMORY_DIV_FACTOR, MAX_TEST_CASES
209
+
210
+ self.dtype = torch.bfloat16
211
+ if self.test_case_idx == 0:
212
+ self.batch_size = 1
213
+ self.num_tokens = 1937
214
+ self.hidden_size = 3584
215
+ self.vocab_size = 152064
216
+ elif self.test_case_idx == 1:
217
+ self.batch_size = 1
218
+ self.num_tokens = 2169
219
+ self.hidden_size = 896
220
+ self.vocab_size = 151936
221
+ elif self.test_case_idx == 2:
222
+ self.batch_size = 1
223
+ self.num_tokens = 1530
224
+ self.hidden_size = 2048
225
+ self.vocab_size = 32256
226
+ elif self.test_case_idx == 3:
227
+ self.batch_size = 1
228
+ self.num_tokens = 1388
229
+ self.hidden_size = 4096
230
+ self.vocab_size = 102400
231
+ elif self.test_case_idx == 4:
232
+ self.batch_size = 1
233
+ self.num_tokens = 8192
234
+ self.hidden_size = 4096
235
+ self.vocab_size = 102400
236
+ else:
237
+ raise ValueError(f"Invalid test case index: {self.test_case_idx}")
238
+ if LOW_MEMORY:
239
+ self.vocab_size = int(self.vocab_size / LOW_MEMORY_DIV_FACTOR)
240
+ assert MAX_TEST_CASES <= 5, "MAX_TEST_CASES should be less than or equal to 5."
241
+
242
+ def generate_forward_inputs(self):
243
+ hidden = (
244
+ torch.empty((self.batch_size, self.num_tokens, self.hidden_size), dtype=self.dtype, device="cuda")
245
+ .uniform_(-0.5, 0.5)
246
+ .requires_grad_()
247
+ )
248
+ weight = (
249
+ torch.empty((self.vocab_size, self.hidden_size), dtype=self.dtype, device="cuda")
250
+ .uniform_(-0.5, 0.5)
251
+ .requires_grad_()
252
+ )
253
+ labels = torch.randint(0, self.vocab_size, (self.batch_size, self.num_tokens), device="cuda")
254
+ return hidden, weight, labels
255
+
256
+ def generate_backward_inputs(self):
257
+ g_entropy = torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-0.5, 0.5)
258
+ g_logprobs = torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-1, 1)
259
+ return g_entropy, g_logprobs
260
+
261
+ def verify_torch_itself(self, iterations: int = 5):
262
+ self.cleanup()
263
+ self.generate_hyper()
264
+
265
+ for i in range(iterations):
266
+ hidden, weight, labels = self.generate_forward_inputs()
267
+
268
+ # NOTE: we need to manually synchronize hidden and labels among Process Group
269
+ dist.broadcast(hidden, src=0, group=self.group)
270
+ dist.broadcast(labels, src=0, group=self.group)
271
+
272
+ # forward pass
273
+ # Create a tensor to hold the gathered weights from all ranks
274
+ # weight has shape [vocab_size, hidden_size]
275
+ # We want to gather along the first dimension to get [vocab_size * world_size, hidden_size]
276
+
277
+ # Create a single contiguous tensor to hold all gathered weights
278
+ whole_weight = torch.empty(
279
+ (self.vocab_size * self.world_size, self.hidden_size), dtype=weight.dtype, device=weight.device
280
+ )
281
+
282
+ # Create views into the tensor for each rank's portion
283
+ whole_weight_views = [
284
+ whole_weight[i * self.vocab_size : (i + 1) * self.vocab_size] for i in range(self.world_size)
285
+ ]
286
+
287
+ # Perform all_gather operation using the views
288
+ dist.all_gather(whole_weight_views, weight, group=self.group)
289
+
290
+ # Set requires_grad for autograd
291
+ whole_weight.requires_grad_()
292
+
293
+ (single_logprobs, single_entropy) = run_torch_entropy(hidden, whole_weight, labels, self.temperature)
294
+
295
+ (tp_logprobs, tp_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.temperature, self.group)
296
+
297
+ torch.testing.assert_close(single_logprobs, tp_logprobs, atol=1e-4, rtol=1e-4)
298
+ torch.testing.assert_close(single_entropy, tp_entropy, atol=1e-4, rtol=1e-4)
299
+
300
+ # backward pass
301
+ g_entropy, g_logprobs = self.generate_backward_inputs()
302
+ # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group
303
+ dist.broadcast(g_entropy, src=0, group=self.group)
304
+ dist.broadcast(g_logprobs, src=0, group=self.group)
305
+
306
+ (single_d_hidden, single_d_weight) = torch.autograd.grad(
307
+ (single_entropy, single_logprobs), (hidden, whole_weight), (g_entropy, g_logprobs), retain_graph=False
308
+ )
309
+
310
+ (tp_d_hidden, tp_d_weight) = torch.autograd.grad(
311
+ (tp_entropy, tp_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False
312
+ )
313
+ # NOTE: all-reduce on hidden is conducted outside the kernel
314
+ dist.all_reduce(tp_d_hidden, op=dist.ReduceOp.SUM, group=self.group)
315
+
316
+ torch.testing.assert_close(tp_d_hidden, single_d_hidden, atol=1e-2, rtol=1e-4)
317
+ # Extract the corresponding slice from single_d_weight for comparison
318
+ # tp_d_weight has shape [vocab_size, hidden_size]
319
+ # single_d_weight has shape [vocab_size * world_size, hidden_size]
320
+ torch.testing.assert_close(
321
+ tp_d_weight,
322
+ single_d_weight[self.local_rank * self.vocab_size : (self.local_rank + 1) * self.vocab_size],
323
+ atol=1e-2,
324
+ rtol=1e-4,
325
+ )
326
+
327
+ # atol=1e-3, rtol=1e-4)
328
+ if self.local_rank == 0:
329
+ print("[PASS] torch TP correctness is verified")
330
+
331
+ def check_torch_storage(self):
332
+ self.cleanup()
333
+ self.generate_hyper()
334
+
335
+ hidden, weight, labels = self.generate_forward_inputs()
336
+
337
+ # NOTE: we need to manually synchronize hidden and labels among Process Group
338
+ dist.broadcast(hidden, src=0, group=self.group)
339
+ dist.broadcast(labels, src=0, group=self.group)
340
+
341
+ torch.cuda.reset_peak_memory_stats()
342
+ (tp_logprobs, tp_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.temperature, self.group)
343
+ torch.cuda.synchronize()
344
+ forward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024
345
+
346
+ g_entropy, g_logprobs = self.generate_backward_inputs()
347
+ # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group
348
+ dist.broadcast(g_entropy, src=0, group=self.group)
349
+ dist.broadcast(g_logprobs, src=0, group=self.group)
350
+
351
+ torch.cuda.reset_peak_memory_stats()
352
+ (d_tp_hidden, d_tp_weight) = torch.autograd.grad(
353
+ (tp_entropy, tp_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False
354
+ )
355
+ torch.cuda.synchronize()
356
+ backward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024
357
+ # NOTE: all-reduce on hidden is conducted outside the kernel
358
+ dist.all_reduce(d_tp_hidden, op=dist.ReduceOp.SUM, group=self.group)
359
+
360
+ if self.local_rank == 0:
361
+ print(f"[INFO]: Torch Forward pass peak memory: {forward_max_memory:.2f} MB")
362
+ print(f"[INFO]: Torch Backward pass peak memory: {backward_max_memory:.2f} MB")
363
+
364
+ def verify_kernel_correctness(self, iterations: int = 5):
365
+ self.cleanup()
366
+ self.generate_hyper()
367
+
368
+ torch_forward_latency = list()
369
+ torch_backward_latency = list()
370
+ kernel_forward_latency = list()
371
+ kernel_backward_latency = list()
372
+
373
+ start_event = torch.cuda.Event(enable_timing=True)
374
+ end_event = torch.cuda.Event(enable_timing=True)
375
+
376
+ for i in range(iterations):
377
+ hidden, weight, labels = self.generate_forward_inputs()
378
+
379
+ # NOTE: we need to manually synchronize hidden and labels among Process Group
380
+ dist.broadcast(hidden, src=0, group=self.group)
381
+ dist.broadcast(labels, src=0, group=self.group)
382
+
383
+ start_event.record()
384
+ (torch_logprobs, torch_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.temperature, self.group)
385
+ end_event.record()
386
+ torch.cuda.synchronize()
387
+ torch_forward_latency.append(start_event.elapsed_time(end_event))
388
+
389
+ start_event.record()
390
+ (kernel_logprobs, kernel_entropy) = linear_cross_entropy(
391
+ hidden, weight, labels, self.temperature, "none", self.group
392
+ )
393
+ end_event.record()
394
+ torch.cuda.synchronize()
395
+ kernel_forward_latency.append(start_event.elapsed_time(end_event))
396
+
397
+ torch.testing.assert_close(torch_logprobs, kernel_logprobs, atol=1e-1, rtol=1e-2)
398
+ torch.testing.assert_close(torch_entropy, kernel_entropy, atol=1e-1, rtol=1e-2)
399
+
400
+ # backward pass
401
+ g_entropy, g_logprobs = self.generate_backward_inputs()
402
+ # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group
403
+ dist.broadcast(g_entropy, src=0, group=self.group)
404
+ dist.broadcast(g_logprobs, src=0, group=self.group)
405
+
406
+ start_event.record()
407
+ (torch_d_hidden, torch_d_weight) = torch.autograd.grad(
408
+ (torch_entropy, torch_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False
409
+ )
410
+ end_event.record()
411
+ torch.cuda.synchronize()
412
+ torch_backward_latency.append(start_event.elapsed_time(end_event))
413
+ # NOTE: all-reduce on hidden is conducted outside the kernel
414
+ dist.all_reduce(torch_d_hidden, op=dist.ReduceOp.SUM, group=self.group)
415
+
416
+ start_event.record()
417
+ (kernel_d_hidden, kernel_d_weight) = torch.autograd.grad(
418
+ (kernel_entropy, kernel_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False
419
+ )
420
+ end_event.record()
421
+ torch.cuda.synchronize()
422
+ kernel_backward_latency.append(start_event.elapsed_time(end_event))
423
+ # NOTE: all-reduce on hidden is conducted outside the kernel
424
+ dist.all_reduce(kernel_d_hidden, op=dist.ReduceOp.SUM, group=self.group)
425
+
426
+ torch.testing.assert_close(torch_d_hidden, kernel_d_hidden, atol=2e-2, rtol=4e-2)
427
+ torch.testing.assert_close(torch_d_weight, kernel_d_weight, atol=2e-2, rtol=4e-2)
428
+
429
+ # remove first latency
430
+ torch_forward_latency = torch_forward_latency[1:]
431
+ torch_backward_latency = torch_backward_latency[1:]
432
+ kernel_forward_latency = kernel_forward_latency[1:]
433
+ kernel_backward_latency = kernel_backward_latency[1:]
434
+
435
+ if self.local_rank == 0:
436
+ print("\n[PASS]: Verified kernel forward & backward correctness.")
437
+
438
+ print(
439
+ f"[INFO]: Forward pass: Torch implementation average time: "
440
+ f"{sum(torch_forward_latency) / len(torch_forward_latency):.2f} ms"
441
+ )
442
+ print(
443
+ f"[INFO]: Backward pass: torch implementation average time: "
444
+ f"{sum(torch_backward_latency) / len(torch_backward_latency):.2f} ms"
445
+ )
446
+ print(
447
+ f"[INFO]: Forward pass: Kernel implementation average time: "
448
+ f"{sum(kernel_forward_latency) / len(kernel_forward_latency):.2f} ms"
449
+ )
450
+ print(
451
+ f"[INFO]: Backward pass: kernel implementation average time: "
452
+ f"{sum(kernel_backward_latency) / len(kernel_backward_latency):.2f} ms"
453
+ )
454
+
455
+ def check_kernel_storage(self):
456
+ self.cleanup()
457
+ self.generate_hyper()
458
+
459
+ hidden, weight, labels = self.generate_forward_inputs()
460
+
461
+ # NOTE: we need to manually synchronize hidden and labels among Process Group
462
+ dist.broadcast(hidden, src=0, group=self.group)
463
+ dist.broadcast(labels, src=0, group=self.group)
464
+
465
+ torch.cuda.reset_peak_memory_stats()
466
+ (kernel_logprobs, kernel_entropy) = linear_cross_entropy(
467
+ hidden, weight, labels, self.temperature, "none", self.group
468
+ )
469
+ torch.cuda.synchronize()
470
+ kernel_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024
471
+
472
+ g_entropy, g_logprobs = self.generate_backward_inputs()
473
+ # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group
474
+ dist.broadcast(g_entropy, src=0, group=self.group)
475
+ dist.broadcast(g_logprobs, src=0, group=self.group)
476
+
477
+ torch.cuda.reset_peak_memory_stats()
478
+ (d_kernel_hidden, d_kernel_weight) = torch.autograd.grad(
479
+ (kernel_entropy, kernel_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False
480
+ )
481
+ torch.cuda.synchronize()
482
+ kernel_backward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024
483
+ # NOTE: all-reduce on hidden is conducted outside the kernel
484
+ dist.all_reduce(d_kernel_hidden, op=dist.ReduceOp.SUM, group=self.group)
485
+
486
+ if self.local_rank == 0:
487
+ print(f"[INFO]: Kernel Forward pass peak memory: {kernel_max_memory:.2f} MB")
488
+ print(f"[INFO]: Kernel Backward pass peak memory: {kernel_backward_max_memory:.2f} MB")
489
+
490
+
491
+ if __name__ == "__main__":
492
+ # TP command: torchrun --standalone --nnodes=1 --nproc-per-node=2 tests/kernels/test_linear_cross_entropy_tp.py
493
+
494
+ # Check if running with torchrun (distributed mode)
495
+ assert int(os.environ["WORLD_SIZE"]) > 1, (
496
+ "[ERROR]: This test is designed to run in distributed mode with torchrun. Please use torchrun to "
497
+ "execute this script."
498
+ )
499
+ torch.manual_seed(233376 + int(os.environ.get("RANK", 0)))
500
+
501
+ # set_backward_method(BackwardEnum._Total_Fuse_MN)
502
+ # set_backward_method(BackwardEnum._Split_Dlogits_N)
503
+
504
+ test = TestLinearCrossEntropy_TensorParallel()
505
+ for test_case_idx in range(MAX_TEST_CASES):
506
+ print(f"[INFO] Running test case {test_case_idx}")
507
+ test.initialize(test_case_idx)
508
+ if VERIFY_TORCH_SELF:
509
+ test.verify_torch_itself()
510
+ test.check_torch_storage()
511
+ test.verify_kernel_correctness()
512
+ test.check_kernel_storage()
513
+
514
+ test.shutdown()
code/RL_model/verl/verl_train/tests/utils/test_special_mstx_profile.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import unittest
16
+ from unittest.mock import MagicMock, patch
17
+
18
+ from verl.utils.profiler.config import NPUToolConfig, ProfilerConfig
19
+ from verl.utils.profiler.mstx_profile import NPUProfiler
20
+ from verl.utils.profiler.profile import DistProfiler
21
+
22
+
23
class TestNPUProfilerInitialization(unittest.TestCase):
    """Construction-time checks for DistProfiler with the NPU (mstx) tool."""

    def setUp(self):
        # _define_count is class-level state shared by all NPUProfiler instances;
        # reset it so tests stay independent.
        NPUProfiler._define_count = 0

    def test_init_with_default_config(self):
        profiler = DistProfiler(rank=0, config=ProfilerConfig(tool="npu"), tool_config=NPUToolConfig())
        self.assertFalse(profiler.check_enable())

    def test_init_with_disabled_config(self):
        profiler = DistProfiler(rank=0, config=ProfilerConfig(enable=False, tool="npu"), tool_config=NPUToolConfig())
        self.assertFalse(profiler.check_enable())

    def test_init_with_all_ranks_true(self):
        cfg = ProfilerConfig(enable=True, all_ranks=True, tool="npu")
        profiler = DistProfiler(rank=0, config=cfg, tool_config=NPUToolConfig())
        self.assertTrue(profiler.check_this_rank())

    def test_init_with_ranks_list(self):
        cfg = ProfilerConfig(enable=True, ranks=[1, 2], tool="npu")
        profiler = DistProfiler(rank=1, config=cfg, tool_config=NPUToolConfig())
        self.assertTrue(profiler.check_this_rank())

    def test_init_with_rank_not_in_ranks(self):
        cfg = ProfilerConfig(enable=True, ranks=[1, 2], tool="npu")
        profiler = DistProfiler(rank=3, config=cfg, tool_config=NPUToolConfig())
        self.assertFalse(profiler.check_this_rank())
56
+
57
+
58
class TestNPUProfilerStart(unittest.TestCase):
    """Behavior of DistProfiler.start() for the NPU (mstx) backend."""

    def setUp(self):
        # Shared class-level counter; reset between tests.
        NPUProfiler._define_count = 0
        self.config = ProfilerConfig(enable=True, ranks=[0], tool="npu")
        self.tool_config = NPUToolConfig(discrete=False)

    @patch("verl.utils.profiler.mstx_profile.get_npu_profiler")
    def test_start_when_enabled_and_this_rank(self, get_profiler_mock):
        profiler = DistProfiler(rank=0, config=self.config, tool_config=self.tool_config)
        profiler.start(role="worker", profile_step="1")
        self.assertTrue(profiler.check_this_step())
        self.assertEqual(NPUProfiler._define_count, 1)
        get_profiler_mock.assert_called_once()

    @patch("verl.utils.profiler.mstx_profile.get_npu_profiler")
    def test_start_when_not_this_rank(self, get_profiler_mock):
        # Rank 1 is not in ranks=[0], so start() must be a no-op.
        profiler = DistProfiler(rank=1, config=self.config, tool_config=self.tool_config)
        profiler.start()
        self.assertFalse(profiler.check_this_step())
        self.assertEqual(NPUProfiler._define_count, 0)
        get_profiler_mock.assert_not_called()

    @patch("verl.utils.profiler.mstx_profile.get_npu_profiler")
    def test_start_discrete_mode_does_not_increase_count(self, get_profiler_mock):
        discrete_cfg = NPUToolConfig(discrete=True)
        profiler = DistProfiler(rank=0, config=self.config, tool_config=discrete_cfg)
        profiler.start()
        self.assertEqual(NPUProfiler._define_count, 0)
        get_profiler_mock.assert_not_called()

    @patch("verl.utils.profiler.mstx_profile.get_npu_profiler")
    def test_multiple_start_calls_do_not_increase_count(self, get_profiler_mock):
        profiler = DistProfiler(rank=0, config=self.config, tool_config=self.tool_config)
        for _ in range(2):
            profiler.start()
        # Only the first start() should define and launch a profiler.
        self.assertEqual(NPUProfiler._define_count, 1)
        get_profiler_mock.assert_called_once()
95
+
96
+
97
class TestNPUProfilerStartStopInteraction(unittest.TestCase):
    """start()/stop() interplay and the shared class-level define count."""

    def setUp(self):
        NPUProfiler._define_count = 0
        self.config = ProfilerConfig(enable=True, ranks=[0], tool="npu")
        self.tool_config = NPUToolConfig(discrete=False)

    @patch("verl.utils.profiler.mstx_profile.get_npu_profiler")
    def test_start_stop_cycle(self, get_profiler_mock):
        npu_profile = MagicMock()
        get_profiler_mock.return_value = npu_profile

        profiler = DistProfiler(rank=0, config=self.config, tool_config=self.tool_config)
        profiler.start()
        self.assertEqual(NPUProfiler._define_count, 1)
        self.assertEqual(npu_profile.start.call_count, 1)

        profiler.stop()
        # stop() steps and stops the underlying profiler and releases the count.
        self.assertEqual(NPUProfiler._define_count, 0)
        self.assertEqual(npu_profile.step.call_count, 1)
        self.assertEqual(npu_profile.stop.call_count, 1)

    @patch("verl.utils.profiler.mstx_profile.get_npu_profiler")
    def test_multiple_instances_share_define_count(self, get_profiler_mock):
        npu_profile = MagicMock()
        get_profiler_mock.return_value = npu_profile

        first = DistProfiler(rank=0, config=self.config, tool_config=self.tool_config)
        second = DistProfiler(rank=0, config=self.config, tool_config=self.tool_config)
        first.start()
        second.start()
        # The second instance sees the count already taken and does nothing.
        self.assertEqual(NPUProfiler._define_count, 1)
        self.assertEqual(npu_profile.start.call_count, 1)
        first.stop()
        self.assertEqual(NPUProfiler._define_count, 0)
130
+
131
+
132
class TestNPUProfilerAnnotate(unittest.TestCase):
    """Tests for the DistProfiler.annotate decorator backed by the NPU (mstx) tool."""

    def setUp(self):
        self.config = ProfilerConfig(enable=True, all_ranks=True, tool="npu")
        self.tool_config = NPUToolConfig(discrete=False)
        self.rank = 0

    def test_annotate_decorator_applied_correctly(self):
        # Non-discrete mode: annotate only marks a range; no per-call profiler.
        mock_worker = MagicMock()
        mock_worker.profiler = DistProfiler(rank=self.rank, config=self.config, tool_config=self.tool_config)
        # Manually set private attribute for testing annotation in active step
        mock_worker.profiler._this_step = True

        mock_mark_range = "mocked_range_handle"

        with (
            patch("verl.utils.profiler.mstx_profile.mark_start_range") as mock_start_patch,
            patch("verl.utils.profiler.mstx_profile.mark_end_range") as mock_end_patch,
        ):
            mock_start_patch.return_value = mock_mark_range

            with patch("verl.utils.profiler.mstx_profile.get_npu_profiler") as mock_get_profiler:
                decorator = mock_worker.profiler.annotate(message="test")

                @decorator
                def test_func(self, *args, **kwargs):
                    return "result"

                result = test_func(mock_worker)

                # Range marked around the call; no profiler instantiated.
                self.assertEqual(result, "result")
                mock_start_patch.assert_called_once_with(message="test")
                mock_end_patch.assert_called_once_with(mock_mark_range)
                mock_get_profiler.assert_not_called()

    def test_annotate_when_profiler_disabled(self):
        # With enable=False the decorator must be a pure pass-through.
        disabled_config = ProfilerConfig(enable=False, tool="npu")
        mock_worker = MagicMock()
        mock_worker.profiler = DistProfiler(rank=self.rank, config=disabled_config, tool_config=self.tool_config)

        with (
            patch("verl.utils.profiler.mstx_profile.mark_start_range") as mock_start_patch,
            patch("verl.utils.profiler.mstx_profile.mark_end_range") as mock_end_patch,
            patch("verl.utils.profiler.mstx_profile.get_npu_profiler") as mock_get_profiler,
        ):
            decorator = mock_worker.profiler.annotate(message="test")

            @decorator
            def test_func(self, *args, **kwargs):
                return "result"

            result = test_func(mock_worker)

            self.assertEqual(result, "result")
            mock_start_patch.assert_not_called()
            mock_end_patch.assert_not_called()
            mock_get_profiler.assert_not_called()

    def test_annotate_when_this_step_disabled(self):
        # Outside an active profiling step, annotate must also be a no-op.
        mock_worker = MagicMock()
        mock_worker.profiler = DistProfiler(rank=self.rank, config=self.config, tool_config=self.tool_config)
        mock_worker.profiler._this_step = False

        with (
            patch("verl.utils.profiler.mstx_profile.mark_start_range") as mock_start_patch,
            patch("verl.utils.profiler.mstx_profile.mark_end_range") as mock_end_patch,
            patch("verl.utils.profiler.mstx_profile.get_npu_profiler") as mock_get_profiler,
        ):
            decorator = mock_worker.profiler.annotate(message="test")

            @decorator
            def test_func(self, *args, **kwargs):
                return "result"

            result = test_func(mock_worker)

            self.assertEqual(result, "result")
            mock_start_patch.assert_not_called()
            mock_end_patch.assert_not_called()
            mock_get_profiler.assert_not_called()

    def test_annotate_discrete_mode_enabled(self):
        # Discrete mode: each annotated call gets its own profiler start/step/stop
        # in addition to the range markers.
        discrete_tool_config = NPUToolConfig(discrete=True)
        mock_worker = MagicMock()
        mock_worker.profiler = DistProfiler(rank=self.rank, config=self.config, tool_config=discrete_tool_config)
        mock_worker.profiler._this_step = True

        mock_mark_range = "mocked_range_handle"
        mock_profile_npu = MagicMock()

        with (
            patch("verl.utils.profiler.mstx_profile.mark_start_range") as mock_start_patch,
            patch("verl.utils.profiler.mstx_profile.mark_end_range") as mock_end_patch,
            patch("verl.utils.profiler.mstx_profile.get_npu_profiler") as mock_get_profiler,
        ):
            mock_start_patch.return_value = mock_mark_range
            mock_get_profiler.return_value = mock_profile_npu
            decorator = mock_worker.profiler.annotate(message="test", role="test_role")

            @decorator
            def test_func(self, *args, **kwargs):
                return "result"

            result = test_func(mock_worker)

            self.assertEqual(result, "result")
            mock_start_patch.assert_called_once_with(message="test")
            mock_end_patch.assert_called_once_with(mock_mark_range)
            # The per-call profiler is created from the tool config plus the role.
            mock_get_profiler.assert_called_once_with(
                contents=mock_worker.profiler._impl.profile_contents,
                profile_level=mock_worker.profiler._impl.profile_level,
                profile_save_path=mock_worker.profiler._impl.profile_save_path,
                analysis=mock_worker.profiler._impl.analysis,
                role="test_role",
            )
            mock_profile_npu.start.assert_called_once()
            mock_profile_npu.step.assert_called_once()
            mock_profile_npu.stop.assert_called_once()

    def test_annotate_with_default_message(self):
        # Without an explicit message, the wrapped function's name is used.
        mock_worker = MagicMock()
        mock_worker.profiler = DistProfiler(rank=self.rank, config=self.config, tool_config=self.tool_config)
        mock_worker.profiler._this_step = True

        mock_mark_range = "mocked_range_handle"
        with (
            patch("verl.utils.profiler.mstx_profile.mark_start_range") as mock_start_patch,
            patch("verl.utils.profiler.mstx_profile.mark_end_range") as mock_end_patch,
        ):
            mock_start_patch.return_value = mock_mark_range
            decorator = mock_worker.profiler.annotate()

            @decorator
            def test_func(self, *args, **kwargs):
                return "result"

            test_func(mock_worker)

            mock_start_patch.assert_called_once_with(message="test_func")
            mock_end_patch.assert_called_once_with(mock_mark_range)
271
+
272
+
273
if __name__ == "__main__":
    # Allow running this test module directly, outside a pytest/CI runner.
    unittest.main()
code/RL_model/verl/verl_train/tests/utils/test_temp_env_on_cpu.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+
17
+ import pytest
18
+
19
+ from verl.utils.py_functional import temp_env_var
20
+
21
+
22
@pytest.fixture(autouse=True)
def clean_env():
    """Snapshot os.environ around each test and scrub the test variables first."""
    # Full snapshot of the pre-test environment.
    snapshot = dict(os.environ)

    # Remove any test variable that might leak in from elsewhere.
    for name in ("TEST_VAR", "TEST_VAR_2", "EXISTING_VAR"):
        os.environ.pop(name, None)

    yield

    # Restore the exact pre-test environment after the test.
    os.environ.clear()
    os.environ.update(snapshot)
40
+
41
+
42
def test_set_new_env_var():
    """A previously-unset variable exists inside the context and is gone after."""
    assert "TEST_VAR" not in os.environ

    with temp_env_var("TEST_VAR", "test_value"):
        assert "TEST_VAR" in os.environ
        assert os.environ["TEST_VAR"] == "test_value"

    assert "TEST_VAR" not in os.environ
54
+
55
+
56
def test_restore_existing_env_var():
    """Overriding an existing variable restores its prior value on exit."""
    key = "EXISTING_VAR"
    os.environ[key] = "original_value"

    with temp_env_var(key, "temporary_value"):
        assert os.environ[key] == "temporary_value"

    assert os.environ[key] == "original_value"
67
+
68
+
69
def test_env_var_restored_on_exception():
    """The original value comes back even when the context body raises."""
    key = "EXISTING_VAR"
    os.environ[key] = "original_value"

    with pytest.raises(ValueError):
        with temp_env_var(key, "temporary_value"):
            assert os.environ[key] == "temporary_value"
            raise ValueError("Test exception")

    # Restoration must survive the exception.
    assert os.environ[key] == "original_value"
83
+
84
+
85
def test_nested_context_managers():
    """Nested temp_env_var contexts unwind like a stack."""
    os.environ["TEST_VAR"] = "original"

    def current():
        return os.environ["TEST_VAR"]

    with temp_env_var("TEST_VAR", "level1"):
        assert current() == "level1"
        with temp_env_var("TEST_VAR", "level2"):
            assert current() == "level2"
        # Inner exit restores the outer context's value.
        assert current() == "level1"

    # Outer exit restores the original value.
    assert current() == "original"
101
+
102
+
103
def test_multiple_different_vars():
    """Two different variables can be managed by simultaneous contexts."""
    os.environ["EXISTING_VAR"] = "existing_value"

    # Single with-statement is equivalent to the nested form (same exit order).
    with temp_env_var("EXISTING_VAR", "modified"), temp_env_var("TEST_VAR", "new_value"):
        assert os.environ["EXISTING_VAR"] == "modified"
        assert os.environ["TEST_VAR"] == "new_value"

    assert os.environ["EXISTING_VAR"] == "existing_value"
    assert "TEST_VAR" not in os.environ
116
+
117
+
118
def test_empty_string_value():
    """An empty string is a legitimate value, distinct from 'unset'."""
    with temp_env_var("TEST_VAR", ""):
        assert "TEST_VAR" in os.environ
        assert os.environ["TEST_VAR"] == ""

    assert "TEST_VAR" not in os.environ
126
+
127
+
128
def test_overwrite_with_empty_string():
    """Overwriting with '' still restores the original value afterwards."""
    key = "EXISTING_VAR"
    os.environ[key] = "original"

    with temp_env_var(key, ""):
        assert os.environ[key] == ""

    assert os.environ[key] == "original"
137
+
138
+
139
def test_context_manager_returns_none():
    """The context manager yields None — there is nothing useful to bind."""
    with temp_env_var("TEST_VAR", "value") as yielded:
        assert yielded is None
        assert os.environ["TEST_VAR"] == "value"
code/RL_model/verl/verl_train/tests/utils/test_timeout_decorator_cpu.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import multiprocessing
16
+ import sys
17
+ import threading
18
+ import time
19
+
20
+ import pytest # Import pytest
21
+
22
+ from verl.utils.py_functional import timeout_limit as timeout
23
+
24
# --- Test Task Functions ---
# Shared timing knobs for every timeout test in this module.
TEST_TIMEOUT_SECONDS = 1.5  # Timeout duration for tests
LONG_TASK_DURATION = TEST_TIMEOUT_SECONDS + 0.5  # Duration slightly longer than timeout
27
+
28
+
29
@timeout(seconds=TEST_TIMEOUT_SECONDS)  # Keep global decorator for mp tests
def quick_task(x):
    """A task that completes quickly (well inside TEST_TIMEOUT_SECONDS).

    ``x`` is unused; it keeps the task signatures uniform for the tests.
    """
    time.sleep(0.1)
    return "quick_ok"
34
+
35
+
36
@timeout(seconds=TEST_TIMEOUT_SECONDS)  # Keep global decorator for mp tests
def slow_task(x):
    """A task that takes longer than the timeout.

    ``x`` is unused; it keeps the task signatures uniform for the tests.
    """
    time.sleep(LONG_TASK_DURATION)
    return "slow_finished"  # This return value indicates it didn't time out
41
+
42
+
43
# Intentionally NOT decorated: tests apply the timeout decorator dynamically.
def task_raises_value_error():
    """A task whose only job is to raise a ValueError."""
    failure = ValueError("Specific value error from task")
    raise failure
47
+
48
+
49
# --- Top-level function for signal test in subprocess ---
# Keep this decorated globally for the specific subprocess test case
@timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=True)
def top_level_decorated_quick_task_signal():
    """A pickleable top-level function decorated with signal timeout."""
    # Mirrors quick_task's behavior, with a distinct return value for clarity.
    time.sleep(0.1)
    return "quick_ok_signal_subprocess"  # Different return for clarity if needed
57
+
58
+
59
# --- Top-level function for signal test in subprocess ---
# Keep this decorated globally for the specific subprocess test case
@timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=True)
def top_level_decorated_slow_task_signal():
    """A pickleable top-level function decorated with signal timeout.

    Sleeps past the limit; returning "slow_finished" means no timeout fired.
    """
    time.sleep(LONG_TASK_DURATION)
    return "slow_finished"
66
+
67
+
68
# --- Top-level helper function to run a target inside a child process ---
def run_target_and_put_in_queue(target_func, q):
    """Run ``target_func`` and report its outcome through ``q``.

    Puts ``("success", result)`` on success, ``("error", exc)`` on any
    exception. Defined at module level so it is pickleable and can serve
    as a ``multiprocessing.Process`` target.
    """
    try:
        q.put(("success", target_func()))
    except Exception as exc:
        q.put(("error", exc))
79
+
80
+
81
# Module-level fixture that pins the multiprocessing start method on macOS.
@pytest.fixture(scope="module", autouse=True)
def set_macos_start_method():
    """Force the 'fork' start method on macOS.

    Avoids pickling issues with globally decorated functions when the tests
    run under pytest discovery. No-op on other platforms.
    """
    if sys.platform != "darwin":
        return
    # None != "fork" is also true, so one comparison covers "unset" as well.
    if multiprocessing.get_start_method(allow_none=True) != "fork":
        try:
            multiprocessing.set_start_method("fork", force=True)
        except RuntimeError:
            # The context may already be started; ignore in that case.
            pass
95
+
96
+
97
def test_quick_task():
    """A fast task completes and returns its normal value."""
    assert quick_task(1) == "quick_ok"
102
+
103
+
104
def test_slow_task_timeout():
    """A task that outlives the limit must raise TimeoutError."""
    with pytest.raises(TimeoutError) as excinfo:
        slow_task(1)
    # The multiprocessing implementation embeds the limit in the message.
    expected_fragment = f"timed out after {TEST_TIMEOUT_SECONDS} seconds"
    assert expected_fragment in str(excinfo.value)
111
+
112
+
113
def test_internal_exception():
    """Exceptions raised inside the task must propagate unchanged."""
    # Decorate the plain task dynamically with the default (mp) timeout.
    wrapped = timeout(seconds=TEST_TIMEOUT_SECONDS)(task_raises_value_error)
    with pytest.raises(ValueError) as excinfo:
        wrapped()
    assert str(excinfo.value) == "Specific value error from task"
120
+
121
+
122
+ # --- Test the signal implementation (use_signals=True) ---
123
+ # Note: As per py_functional.py, use_signals=True currently falls back to
124
+ # multiprocessing on POSIX. These tests verify that behavior.
125
+
126
+
127
def test_signal_quick_task_main_process():
    """Signal-based timeout passes a fast task through in the main process."""

    def plain_quick_task_logic():
        time.sleep(0.1)
        return "quick_ok_signal"

    wrapped = timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=True)(plain_quick_task_logic)
    assert wrapped() == "quick_ok_signal"
137
+
138
+
139
def test_signal_slow_task_main_process_timeout():
    """Signal-based timeout raises TimeoutError for a slow task in the main process."""

    def plain_slow_task_logic():
        time.sleep(LONG_TASK_DURATION)
        return "slow_finished_signal"

    wrapped = timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=True)(plain_slow_task_logic)
    with pytest.raises(TimeoutError) as excinfo:
        wrapped()
    # use_signals=True falls back to multiprocessing on POSIX, hence this message.
    assert f"timed out after {TEST_TIMEOUT_SECONDS} seconds" in str(excinfo.value)
152
+
153
+
154
@pytest.mark.skip(reason="this test won't pass. Just to show why use_signals should not be used")
def test_signal_in_thread_does_not_timeout():
    """
    Tests that signal-based timeout does NOT work reliably in a child thread.
    The TimeoutError from the signal handler is not expected to be raised.
    """
    result_container = []  # Use a list to store result from thread
    exception_container = []  # Use a list to store exception from thread

    @timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=True)
    def slow_task_in_thread():
        try:
            print("Thread: Starting slow task...")
            time.sleep(LONG_TASK_DURATION)
            print("Thread: Slow task finished.")
            return "slow_finished_in_thread"
        except Exception as e:
            # Catch any exception within the thread's target function
            print(f"Thread: Caught exception: {e}")
            exception_container.append(e)
            return None  # Indicate failure

    def thread_target():
        try:
            # Run the decorated function inside the thread
            res = slow_task_in_thread()
            if res is not None:
                result_container.append(res)
        except Exception as e:
            # This might catch exceptions happening *outside* the decorated function
            # but still within the thread target, though less likely here.
            print(f"Thread Target: Caught exception: {e}")
            exception_container.append(e)

    thread = threading.Thread(target=thread_target)
    print("Main: Starting thread...")
    thread.start()
    # Wait longer than the timeout + task duration to ensure the thread finishes
    # regardless of whether timeout worked or not.
    thread.join(timeout=LONG_TASK_DURATION + 1)

    # These assertions state the *desired* behavior; the skip marker above
    # records that signals cannot deliver it inside a child thread.
    assert len(exception_container) == 1
    assert isinstance(exception_container[0], TimeoutError)
    assert not result_container
198
+
199
+
200
def test_in_thread_timeout():
    """The multiprocessing-based timeout (use_signals=False) works from a child thread."""
    results = []  # Filled by the worker thread on (unexpected) success.
    errors = []  # Filled by the worker thread with any caught exception.

    @timeout(seconds=TEST_TIMEOUT_SECONDS, use_signals=False)
    def slow_task_in_thread():
        try:
            print("Thread: Starting slow task...")
            time.sleep(LONG_TASK_DURATION)
            print("Thread: Slow task finished.")
            return "slow_finished_in_thread"
        except Exception as e:
            # The decorator's TimeoutError lands here, inside the thread.
            print(f"Thread: Caught exception: {e}")
            errors.append(e)
            return None

    def thread_target():
        try:
            res = slow_task_in_thread()
            if res is not None:
                results.append(res)
        except Exception as e:
            # Exceptions escaping the decorated call would land here instead.
            print(f"Thread Target: Caught exception: {e}")
            errors.append(e)

    worker = threading.Thread(target=thread_target)
    print("Main: Starting thread...")
    worker.start()
    # Give the thread enough time to finish whether or not the timeout fires.
    worker.join(timeout=LONG_TASK_DURATION + 1)

    assert len(errors) == 1
    assert isinstance(errors[0], TimeoutError)
    assert not results
code/RL_model/verl/verl_train/tests/utils/test_torch_functional.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+
17
+ import pytest
18
+ import torch
19
+ import torch.distributed as dist
20
+ import torch.multiprocessing as mp
21
+
22
+ from verl.utils.device import get_device_name, get_nccl_backend, get_torch_device
23
+ from verl.utils.torch_functional import (
24
+ distributed_masked_mean,
25
+ distributed_mean_max_min_std,
26
+ expand_as_nested,
27
+ masked_mean,
28
+ )
29
+
30
+
31
def _worker_mean(rank: int, world_size: int, rendezvous_file: str):
    """Per-rank worker: rank r contributes the value r+1 and every rank checks
    the globally reduced mean/max/min/std against host-side references."""
    # Pin this process to its device, then join the group via file rendezvous.
    get_torch_device().set_device(rank)
    dist.init_process_group(
        backend=get_nccl_backend(),
        init_method=f"file://{rendezvous_file}",
        rank=rank,
        world_size=world_size,
    )

    local = torch.tensor([float(rank + 1)], device=f"{get_device_name()}:{rank}")
    mean, gmax, gmin, gstd = distributed_mean_max_min_std(local, True, True, True)

    # Reference statistics computed locally over all ranks' contributions.
    values = [float(i + 1) for i in range(world_size)]
    exp_mean = sum(values) / len(values)
    # Sample standard deviation (ddof=1).
    exp_std = (sum((v - exp_mean) ** 2 for v in values) / (len(values) - 1)) ** 0.5

    # Every rank must observe identical reduced results.
    assert torch.allclose(mean.cpu(), torch.tensor(exp_mean)), f"mean@{rank}"
    assert torch.allclose(gmax.cpu(), torch.tensor(max(values))), f"max@{rank}"
    assert torch.allclose(gmin.cpu(), torch.tensor(min(values))), f"min@{rank}"
    assert torch.allclose(gstd.cpu(), torch.tensor(exp_std)), f"std@{rank}"

    dist.destroy_process_group()
58
+
59
+
60
@pytest.mark.parametrize(
    "value,mask,gt",
    [
        ([1.0, 2.0, 3.0, 4.0], [1, 0, 0, 1], 2.5),
        ([1.0, 2.0, float("nan"), 4.0], [1, 0, 0, 1], 2.5),
        ([1.0, 2.0, float("nan"), 4.0], [1, 0, 1, 0], float("nan")),
    ],
)
def test_masked_mean(value, mask, gt):
    """masked_mean averages unmasked entries; a NaN under a zero mask must not
    poison the result, while a NaN under a one mask propagates."""
    actual = masked_mean(torch.tensor(value), torch.tensor(mask))
    expected = torch.tensor(gt)
    # Pass when values agree, or when both are NaN (allclose rejects NaN pairs).
    assert torch.allclose(actual, expected) or (torch.isnan(actual) and torch.isnan(expected))
72
+
73
+
74
+ @pytest.mark.parametrize("world_size", [2, 4])
75
+ def test_distributed_mean_max_min_std(world_size, tmp_path):
76
+ rendezvous_file = str(tmp_path / "rdzv_mean")
77
+ os.makedirs(os.path.dirname(rendezvous_file), exist_ok=True)
78
+
79
+ mp.spawn(
80
+ fn=_worker_mean,
81
+ args=(world_size, rendezvous_file),
82
+ nprocs=world_size,
83
+ join=True,
84
+ )
85
+
86
+
87
def _worker_mask(rank: int, world_size: int, rendezvous_file: str):
    """Per-rank worker: checks distributed_masked_mean with a different mask on
    rank 0 than on all other ranks."""
    get_torch_device().set_device(rank)
    dist.init_process_group(
        backend=get_nccl_backend(),
        init_method=f"file://{rendezvous_file}",
        rank=rank,
        world_size=world_size,
    )

    device = f"{get_device_name()}:{rank}"
    # Rank r holds [2r+1, 2r+2]; rank 0 masks in its first entry, every other
    # rank masks in its second.
    local_tensor = torch.tensor([rank * 2 + 1.0, rank * 2 + 2.0], device=device)
    selected = [1, 0] if rank == 0 else [0, 1]
    mask = torch.tensor(selected, device=device, dtype=torch.float32)

    gmean = distributed_masked_mean(local_tensor, mask)

    # Surviving values: 1.0 from rank 0 and 2i+2 from each rank i >= 1.
    valid_values = [1.0] + [2 * i + 2.0 for i in range(1, world_size)]
    expected_mean = sum(valid_values) / len(valid_values)
    assert torch.allclose(gmean.cpu(), torch.tensor(expected_mean)), f"masked_mean@{rank}"

    dist.destroy_process_group()
110
+
111
+
112
+ @pytest.mark.parametrize("world_size", [2, 4])
113
+ def test_distributed_masked_mean(world_size, tmp_path):
114
+ rendezvous_file = str(tmp_path / "rdzv_mask")
115
+ os.makedirs(os.path.dirname(rendezvous_file), exist_ok=True)
116
+
117
+ mp.spawn(
118
+ fn=_worker_mask,
119
+ args=(world_size, rendezvous_file),
120
+ nprocs=world_size,
121
+ join=True,
122
+ )
123
+
124
+
125
def test_expand_as_nested():
    """expand_as_nested repeats element i of a flat tensor once per element of
    row i of a jagged nested tensor, and rejects shape-incompatible inputs."""
    rows = [torch.randn(n) for n in (2, 3, 4)]
    nested_tensor = torch.nested.as_nested_tensor(rows, layout=torch.jagged)
    tensor = torch.tensor([1, 2, 3])

    output = expand_as_nested(tensor, nested_tensor)

    # Element i repeats len(rows[i]) times: 1 x2, 2 x3, 3 x4.
    assert output.values().tolist() == [1, 1, 2, 2, 2, 3, 3, 3, 3]
    # The jagged structure of the target must be preserved.
    assert torch.all(output.offsets() == nested_tensor.offsets()).item()

    # Invalid combinations must raise.
    with pytest.raises(AssertionError):
        expand_as_nested(tensor, tensor)  # second arg is a regular tensor

    with pytest.raises(AssertionError):
        expand_as_nested(torch.tensor([1, 2, 3, 4]), nested_tensor)  # 4 vs 3 rows

    with pytest.raises(AssertionError):
        expand_as_nested(torch.tensor([[1, 2, 3]]), nested_tensor)  # 2-D source

    with pytest.raises(AssertionError):
        expand_as_nested(tensor, nested_tensor.unsqueeze(-1))  # extra trailing dim