# Copyright 2026 Bytedance Ltd. and/or its affiliates # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import unittest from unittest.mock import MagicMock, patch import torch from verl.utils.profiler.config import ProfilerConfig, TorchProfilerToolConfig from verl.utils.profiler.torch_profile import Profiler, get_torch_profiler class TestTorchProfile(unittest.TestCase): def setUp(self): # Reset Profiler class state Profiler._define_count = 0 @patch("torch.profiler.profile") def test_get_torch_profiler(self, mock_profile): # Test wrapper function get_torch_profiler(contents=["cpu", "cuda", "stack"], save_path="/tmp/test", rank=0) mock_profile.assert_called_once() _, kwargs = mock_profile.call_args # Verify activities activities = kwargs["activities"] self.assertIn(torch.profiler.ProfilerActivity.CPU, activities) self.assertIn(torch.profiler.ProfilerActivity.CUDA, activities) # Verify options self.assertTrue(kwargs["with_stack"]) self.assertFalse(kwargs["record_shapes"]) self.assertFalse(kwargs["profile_memory"]) @patch("verl.utils.profiler.torch_profile.get_torch_profiler") def test_profiler_lifecycle(self, mock_get_profiler): # Mock the underlying torch profiler object mock_prof_instance = MagicMock() mock_get_profiler.return_value = mock_prof_instance # Initialize tool_config = TorchProfilerToolConfig(contents=["cpu"], discrete=False) config = ProfilerConfig(save_path="/tmp/test", enable=True, tool_config=tool_config) profiler = Profiler(rank=0, config=config, tool_config=tool_config) # Test Start profiler.start() mock_get_profiler.assert_called_once() mock_prof_instance.start.assert_called_once() # Test Step profiler.step() mock_prof_instance.step.assert_called_once() # Test Stop profiler.stop() mock_prof_instance.stop.assert_called_once() @patch("verl.utils.profiler.torch_profile.get_torch_profiler") def test_discrete_mode(self, mock_get_profiler): # Mock for discrete mode mock_prof_instance = MagicMock() mock_get_profiler.return_value = mock_prof_instance tool_config = TorchProfilerToolConfig(contents=["cpu"], discrete=True) config = ProfilerConfig(save_path="/tmp/test", enable=True, tool_config=tool_config) profiler = Profiler(rank=0, config=config, tool_config=tool_config) # In discrete mode, start/stop shouldn't trigger global profiler immediately profiler.start() mock_get_profiler.assert_not_called() profiler.stop() mock_prof_instance.stop.assert_not_called() if __name__ == "__main__": unittest.main()