| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import unittest |
| from unittest.mock import MagicMock, patch |
|
|
| import torch |
|
|
| from verl.utils.profiler.config import ProfilerConfig, TorchProfilerToolConfig |
| from verl.utils.profiler.torch_profile import Profiler, get_torch_profiler |
|
|
|
|
class TestTorchProfile(unittest.TestCase):
    """Unit tests for ``get_torch_profiler`` and the ``Profiler`` wrapper.

    All tests replace the real profiler objects with mocks, so only the
    calls made by the code under test are checked.
    """

    def setUp(self):
        """Zero ``Profiler._define_count`` before each test runs."""
        # NOTE(review): presumably class-level state shared between
        # instances; resetting it keeps the tests independent of run order.
        Profiler._define_count = 0

    def _build_profiler(self, *, discrete):
        """Return a rank-0 ``Profiler`` with CPU-only contents.

        Args:
            discrete: Value forwarded to ``TorchProfilerToolConfig(discrete=...)``.
        """
        tool_cfg = TorchProfilerToolConfig(contents=["cpu"], discrete=discrete)
        profiler_cfg = ProfilerConfig(
            save_path="/tmp/test",
            enable=True,
            tool_config=tool_cfg,
        )
        return Profiler(rank=0, config=profiler_cfg, tool_config=tool_cfg)

    @patch("torch.profiler.profile")
    def test_get_torch_profiler(self, mock_profile):
        """``get_torch_profiler`` builds ``torch.profiler.profile`` once with the expected kwargs."""
        get_torch_profiler(contents=["cpu", "cuda", "stack"], save_path="/tmp/test", rank=0)
        mock_profile.assert_called_once()
        passed = mock_profile.call_args.kwargs

        # Both requested activities must be handed to the torch profiler.
        requested_activities = passed["activities"]
        for expected in (
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ):
            self.assertIn(expected, requested_activities)

        # Only "stack" was requested, so the other optional flags stay off.
        self.assertTrue(passed["with_stack"])
        self.assertFalse(passed["record_shapes"])
        self.assertFalse(passed["profile_memory"])

    @patch("verl.utils.profiler.torch_profile.get_torch_profiler")
    def test_profiler_lifecycle(self, mock_get_profiler):
        """With ``discrete=False``, start/step/stop each reach the underlying profiler once."""
        underlying = MagicMock()
        mock_get_profiler.return_value = underlying

        wrapper = self._build_profiler(discrete=False)

        # start() creates the underlying profiler and starts it.
        wrapper.start()
        mock_get_profiler.assert_called_once()
        underlying.start.assert_called_once()

        # step() is forwarded.
        wrapper.step()
        underlying.step.assert_called_once()

        # stop() is forwarded.
        wrapper.stop()
        underlying.stop.assert_called_once()

    @patch("verl.utils.profiler.torch_profile.get_torch_profiler")
    def test_discrete_mode(self, mock_get_profiler):
        """With ``discrete=True``, start() creates no profiler and stop() stops nothing."""
        underlying = MagicMock()
        mock_get_profiler.return_value = underlying

        wrapper = self._build_profiler(discrete=True)

        # In discrete mode start() must not construct a profiler.
        wrapper.start()
        mock_get_profiler.assert_not_called()

        wrapper.stop()
        underlying.stop.assert_not_called()
|
|
|
|
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
|