Buckets:
| # Copyright (c) 2025 SandAI. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """ | |
| Tests for CUDA Graph compilation modes. | |
| This module tests: | |
| - FULL mode: entire forward pass captured as a single graph | |
| - PIECEWISE mode: forward pass split at custom ops into multiple graphs | |
| - Parameter handling: nn.Parameter should be ignored in graph capture | |
| """ | |
| import os | |
| from typing import Optional | |
| from unittest.mock import patch | |
| import pytest | |
| import torch | |
| import torch.nn as nn | |
| from torch.testing import assert_close | |
| from magi_compiler.api import magi_compile, magi_register_custom_op | |
| from magi_compiler.config import CompileConfig, CompileMode, CudaGraphMode | |
| from magi_compiler.magi_backend.cuda_graph_mgr import CudaGraphMgr | |
| pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") | |
| def my_split_op(x: torch.Tensor) -> torch.Tensor: | |
| return x.clone() | |
| def cuda_graph_test_env(tmp_path): | |
| """Fixture to set up a clean CUDA graph test environment.""" | |
| # Create an isolated CudaGraphMgr instance for testing | |
| test_mgr = CudaGraphMgr() | |
| test_mgr.cache = dict() | |
| with patch("magi_compiler.magi_backend.cuda_graph_mgr.cuda_graph_mgr", return_value=test_mgr): | |
| yield tmp_path, test_mgr | |
| def _create_compile_config(cache_dir: str, cudagraph_mode: CudaGraphMode) -> CompileConfig: | |
| """Create a compile configuration for CUDA graph testing.""" | |
| from magi_compiler.config import get_compile_config | |
| return CompileConfig( | |
| compile_mode=CompileMode.MAGI_COMPILE, | |
| backend="inductor", | |
| cudagraph_mode=cudagraph_mode, | |
| splitting_ops=get_compile_config().splitting_ops, | |
| cache_root_dir=cache_dir, | |
| ) | |
| class TestCudaGraphFullMode: | |
| """Tests for CudaGraphMode.FULL - entire forward as single graph.""" | |
| def test_full_mode_basic(self, cuda_graph_test_env): | |
| """Test basic FULL mode functionality with graph reuse.""" | |
| tmp_path, test_mgr = cuda_graph_test_env | |
| cache_dir = os.path.join(str(tmp_path), "cache_full") | |
| os.makedirs(cache_dir, exist_ok=True) | |
| compile_config = _create_compile_config(cache_dir, CudaGraphMode.FULL) | |
| with patch("magi_compiler._api.get_compile_config", return_value=compile_config), patch( | |
| "torch.distributed.get_rank", return_value=0 | |
| ): | |
| class FullModeModel(nn.Module): | |
| def __init__(self, model_config: Optional[dict] = None): | |
| super().__init__() | |
| self.linear1 = nn.Linear(10, 20).cuda() | |
| self.linear2 = nn.Linear(20, 5).cuda() | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| x = torch.relu(self.linear1(x)) | |
| return self.linear2(x) | |
| model = FullModeModel(model_config=None).cuda() | |
| with torch.no_grad(): | |
| # Prepare test inputs | |
| x1 = torch.randn(4, 10).cuda() | |
| x2 = x1.clone() | |
| x3 = torch.randn(2, 10).cuda() | |
| x4 = torch.randn(6, 10).cuda() | |
| x5 = x1.clone() | |
| import magi_compiler.magi_backend.cuda_graph_mgr as cgm | |
| active_mgr = cgm.cuda_graph_mgr() | |
| # First run: capture graph | |
| output1 = model(x1) | |
| assert output1.shape == (4, 5) | |
| assert active_mgr.graph_count == 1 | |
| assert active_mgr.tensor_entry_count == 1 | |
| # Same shape input: reuse tensor and graph | |
| output2 = model(x2) | |
| assert_close(output1, output2, rtol=1e-4, atol=1e-4) | |
| assert active_mgr.tensor_entry_count == 1 | |
| assert active_mgr.graph_count == 1 | |
| # Smaller batch: reuse tensor, new graph | |
| output3 = model(x3) | |
| assert output3.shape == (2, 5) | |
| assert active_mgr.tensor_entry_count == 1 | |
| assert active_mgr.graph_count == 2 | |
| # Larger batch: expand tensor, invalidate previous graphs | |
| output4 = model(x4) | |
| assert output4.shape == (6, 5) | |
| assert active_mgr.tensor_entry_count == 1 | |
| assert active_mgr.graph_count == 1 | |
| # Return to original batch size: recapture graph | |
| output5 = model(x5) | |
| assert_close(output1, output5, rtol=1e-4, atol=1e-4) | |
| assert active_mgr.tensor_entry_count == 1 | |
| assert active_mgr.graph_count == 2 | |
| class TestCudaGraphPiecewiseMode: | |
| """Tests for CudaGraphMode.PIECEWISE - split at custom ops.""" | |
| def test_piecewise_mode_with_split_op(self, cuda_graph_test_env): | |
| """Test PIECEWISE mode with custom splitting ops.""" | |
| tmp_path, test_mgr = cuda_graph_test_env | |
| cache_dir = os.path.join(str(tmp_path), "cache_piecewise") | |
| os.makedirs(cache_dir, exist_ok=True) | |
| compile_config = _create_compile_config(cache_dir, CudaGraphMode.PIECEWISE) | |
| with patch("magi_compiler._api.get_compile_config", return_value=compile_config), patch( | |
| "torch.distributed.get_rank", return_value=0 | |
| ): | |
| class PiecewiseModel(nn.Module): | |
| def __init__(self, *, model_config): | |
| super().__init__() | |
| self.linear1 = nn.Linear(10, 20).cuda() | |
| self.linear2 = nn.Linear(20, 5).cuda() | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| x = self.linear1(x) | |
| x = torch.ops.athena.my_split_op(x) # Split point | |
| x = self.linear2(x) | |
| return x | |
| model = PiecewiseModel(model_config=None).cuda() | |
| with torch.no_grad(): | |
| x1 = torch.randn(4, 10).cuda() | |
| x2 = torch.randn(2, 10).cuda() | |
| x3 = torch.randn(6, 10).cuda() | |
| x4 = torch.randn(4, 10).cuda() | |
| x5 = x1.clone() | |
| import magi_compiler.magi_backend.cuda_graph_mgr as cgm | |
| active_mgr = cgm.cuda_graph_mgr() | |
| # First run: capture 2 sub-graphs (before and after split op) | |
| output1 = model(x1) | |
| assert output1.shape == (4, 5) | |
| assert active_mgr.tensor_entry_count == 2 # Two tensor entries for two sub-graphs | |
| assert active_mgr.graph_count == 2 | |
| # Same input: reuse all | |
| output2 = model(x1) | |
| assert_close(output1, output2, rtol=1e-4, atol=1e-4) | |
| assert active_mgr.tensor_entry_count == 2 | |
| assert active_mgr.graph_count == 2 | |
| # Smaller batch: reuse tensors, new sub-graphs | |
| output3 = model(x2) | |
| assert output3.shape == (2, 5) | |
| assert active_mgr.tensor_entry_count == 2 | |
| assert active_mgr.graph_count == 4 | |
| # Larger batch: expand tensors, invalidate previous sub-graphs | |
| output4 = model(x3) | |
| assert output4.shape == (6, 5) | |
| assert active_mgr.tensor_entry_count == 2 | |
| assert active_mgr.graph_count == 2 | |
| # Return to batch=4: reuse tensors, recapture sub-graphs | |
| output5 = model(x4) | |
| assert output5.shape == (4, 5) | |
| assert active_mgr.tensor_entry_count == 2 | |
| assert active_mgr.graph_count == 4 | |
| # Same as first input: verify output consistency | |
| output6 = model(x5) | |
| assert_close(output1, output6, rtol=1e-4, atol=1e-4) | |
| assert active_mgr.tensor_entry_count == 2 | |
| assert active_mgr.graph_count == 4 | |
| class TestCudaGraphParameterHandling: | |
| """Tests for nn.Parameter handling in CUDA graph capture.""" | |
| def test_parameters_excluded_from_graph_inputs(self, cuda_graph_test_env): | |
| """Test that nn.Parameters are not included in graph input tensors.""" | |
| tmp_path, test_mgr = cuda_graph_test_env | |
| cache_dir = os.path.join(str(tmp_path), "cache_params") | |
| os.makedirs(cache_dir, exist_ok=True) | |
| compile_config = _create_compile_config(cache_dir, CudaGraphMode.FULL) | |
| with patch("magi_compiler._api.get_compile_config", return_value=compile_config), patch( | |
| "torch.distributed.get_rank", return_value=0 | |
| ): | |
| class ParamModel(nn.Module): | |
| def __init__(self, model_config: Optional[dict] = None): | |
| super().__init__() | |
| # Multiple nn.Parameters (simulating model weights) | |
| self.weight1 = nn.Parameter(torch.randn(10, 20).cuda()) | |
| self.bias1 = nn.Parameter(torch.randn(20).cuda()) | |
| self.weight2 = nn.Parameter(torch.randn(20, 5).cuda()) | |
| self.bias2 = nn.Parameter(torch.randn(5).cuda()) | |
| self.nested_params = nn.ParameterList( | |
| [nn.Parameter(torch.randn(3, 5).cuda()), nn.Parameter(torch.randn(4, 4).cuda())] | |
| ) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| x = torch.matmul(x, self.weight1) + self.bias1 | |
| x = torch.relu(x) | |
| x = torch.matmul(x, self.weight2) + self.bias2 | |
| x = x + torch.matmul(torch.ones_like(x[:, :3]), self.nested_params[0]) | |
| return x | |
| model = ParamModel(model_config=None).cuda() | |
| with torch.no_grad(): | |
| x1 = torch.randn(4, 10).cuda() | |
| output1 = model(x1) | |
| assert output1.shape == (4, 5) | |
| import magi_compiler.magi_backend.cuda_graph_mgr as cgm | |
| active_mgr = cgm.cuda_graph_mgr() | |
| # Only 1 tensor entry (the input x), not the parameters | |
| assert active_mgr.tensor_entry_count == 1 | |
| # Verify the static entry contains only the input tensor | |
| static_entry = next(iter(active_mgr.cache.values())) | |
| assert len(static_entry.input_tensors) == 1 | |
| assert len(static_entry.output_tensors) == 1 | |
| assert isinstance(static_entry.input_tensors[0], torch.Tensor) | |
| assert not isinstance(static_entry.input_tensors[0], nn.Parameter) | |
| # Additional verification: ArgsUtils extracts only input tensor | |
| from magi_compiler.magi_backend.cuda_graph_mgr import ArgsUtils | |
| input_obj = {"args": (x1,), "kwargs": {}} | |
| extracted_tensors, _, _ = ArgsUtils.recursive_extract_core(input_obj) | |
| assert len(extracted_tensors) == 1 | |
| assert extracted_tensors[0].data_ptr() == x1.data_ptr() | |
| if __name__ == "__main__": | |
| pytest.main([__file__, "-v"]) | |
Xet Storage Details
- Size:
- 11.7 kB
- Xet hash:
- df3f9547714c886e8b27313f599753a61f478a07b9e112b70fa1123161fe7074
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.