download
raw
11.7 kB
# Copyright (c) 2025 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests for CUDA Graph compilation modes.
This module tests:
- FULL mode: entire forward pass captured as a single graph
- PIECEWISE mode: forward pass split at custom ops into multiple graphs
- Parameter handling: nn.Parameter should be ignored in graph capture
"""
import os
from typing import Optional
from unittest.mock import patch
import pytest
import torch
import torch.nn as nn
from torch.testing import assert_close
from magi_compiler.api import magi_compile, magi_register_custom_op
from magi_compiler.config import CompileConfig, CompileMode, CudaGraphMode
from magi_compiler.magi_backend.cuda_graph_mgr import CudaGraphMgr
pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@magi_register_custom_op("athena::my_split_op", infer_output_meta_fn=["x"], is_subgraph_boundary=True)
def my_split_op(x: torch.Tensor) -> torch.Tensor:
return x.clone()
@pytest.fixture(autouse=True)
def cuda_graph_test_env(tmp_path):
"""Fixture to set up a clean CUDA graph test environment."""
# Create an isolated CudaGraphMgr instance for testing
test_mgr = CudaGraphMgr()
test_mgr.cache = dict()
with patch("magi_compiler.magi_backend.cuda_graph_mgr.cuda_graph_mgr", return_value=test_mgr):
yield tmp_path, test_mgr
def _create_compile_config(cache_dir: str, cudagraph_mode: CudaGraphMode) -> CompileConfig:
"""Create a compile configuration for CUDA graph testing."""
from magi_compiler.config import get_compile_config
return CompileConfig(
compile_mode=CompileMode.MAGI_COMPILE,
backend="inductor",
cudagraph_mode=cudagraph_mode,
splitting_ops=get_compile_config().splitting_ops,
cache_root_dir=cache_dir,
)
class TestCudaGraphFullMode:
"""Tests for CudaGraphMode.FULL - entire forward as single graph."""
def test_full_mode_basic(self, cuda_graph_test_env):
"""Test basic FULL mode functionality with graph reuse."""
tmp_path, test_mgr = cuda_graph_test_env
cache_dir = os.path.join(str(tmp_path), "cache_full")
os.makedirs(cache_dir, exist_ok=True)
compile_config = _create_compile_config(cache_dir, CudaGraphMode.FULL)
with patch("magi_compiler._api.get_compile_config", return_value=compile_config), patch(
"torch.distributed.get_rank", return_value=0
):
@magi_compile(dynamic_arg_dims={"x": 0})
class FullModeModel(nn.Module):
def __init__(self, model_config: Optional[dict] = None):
super().__init__()
self.linear1 = nn.Linear(10, 20).cuda()
self.linear2 = nn.Linear(20, 5).cuda()
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = torch.relu(self.linear1(x))
return self.linear2(x)
model = FullModeModel(model_config=None).cuda()
with torch.no_grad():
# Prepare test inputs
x1 = torch.randn(4, 10).cuda()
x2 = x1.clone()
x3 = torch.randn(2, 10).cuda()
x4 = torch.randn(6, 10).cuda()
x5 = x1.clone()
import magi_compiler.magi_backend.cuda_graph_mgr as cgm
active_mgr = cgm.cuda_graph_mgr()
# First run: capture graph
output1 = model(x1)
assert output1.shape == (4, 5)
assert active_mgr.graph_count == 1
assert active_mgr.tensor_entry_count == 1
# Same shape input: reuse tensor and graph
output2 = model(x2)
assert_close(output1, output2, rtol=1e-4, atol=1e-4)
assert active_mgr.tensor_entry_count == 1
assert active_mgr.graph_count == 1
# Smaller batch: reuse tensor, new graph
output3 = model(x3)
assert output3.shape == (2, 5)
assert active_mgr.tensor_entry_count == 1
assert active_mgr.graph_count == 2
# Larger batch: expand tensor, invalidate previous graphs
output4 = model(x4)
assert output4.shape == (6, 5)
assert active_mgr.tensor_entry_count == 1
assert active_mgr.graph_count == 1
# Return to original batch size: recapture graph
output5 = model(x5)
assert_close(output1, output5, rtol=1e-4, atol=1e-4)
assert active_mgr.tensor_entry_count == 1
assert active_mgr.graph_count == 2
class TestCudaGraphPiecewiseMode:
"""Tests for CudaGraphMode.PIECEWISE - split at custom ops."""
def test_piecewise_mode_with_split_op(self, cuda_graph_test_env):
"""Test PIECEWISE mode with custom splitting ops."""
tmp_path, test_mgr = cuda_graph_test_env
cache_dir = os.path.join(str(tmp_path), "cache_piecewise")
os.makedirs(cache_dir, exist_ok=True)
compile_config = _create_compile_config(cache_dir, CudaGraphMode.PIECEWISE)
with patch("magi_compiler._api.get_compile_config", return_value=compile_config), patch(
"torch.distributed.get_rank", return_value=0
):
@magi_compile(dynamic_arg_dims={"x": 0})
class PiecewiseModel(nn.Module):
def __init__(self, *, model_config):
super().__init__()
self.linear1 = nn.Linear(10, 20).cuda()
self.linear2 = nn.Linear(20, 5).cuda()
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.linear1(x)
x = torch.ops.athena.my_split_op(x) # Split point
x = self.linear2(x)
return x
model = PiecewiseModel(model_config=None).cuda()
with torch.no_grad():
x1 = torch.randn(4, 10).cuda()
x2 = torch.randn(2, 10).cuda()
x3 = torch.randn(6, 10).cuda()
x4 = torch.randn(4, 10).cuda()
x5 = x1.clone()
import magi_compiler.magi_backend.cuda_graph_mgr as cgm
active_mgr = cgm.cuda_graph_mgr()
# First run: capture 2 sub-graphs (before and after split op)
output1 = model(x1)
assert output1.shape == (4, 5)
assert active_mgr.tensor_entry_count == 2 # Two tensor entries for two sub-graphs
assert active_mgr.graph_count == 2
# Same input: reuse all
output2 = model(x1)
assert_close(output1, output2, rtol=1e-4, atol=1e-4)
assert active_mgr.tensor_entry_count == 2
assert active_mgr.graph_count == 2
# Smaller batch: reuse tensors, new sub-graphs
output3 = model(x2)
assert output3.shape == (2, 5)
assert active_mgr.tensor_entry_count == 2
assert active_mgr.graph_count == 4
# Larger batch: expand tensors, invalidate previous sub-graphs
output4 = model(x3)
assert output4.shape == (6, 5)
assert active_mgr.tensor_entry_count == 2
assert active_mgr.graph_count == 2
# Return to batch=4: reuse tensors, recapture sub-graphs
output5 = model(x4)
assert output5.shape == (4, 5)
assert active_mgr.tensor_entry_count == 2
assert active_mgr.graph_count == 4
# Same as first input: verify output consistency
output6 = model(x5)
assert_close(output1, output6, rtol=1e-4, atol=1e-4)
assert active_mgr.tensor_entry_count == 2
assert active_mgr.graph_count == 4
class TestCudaGraphParameterHandling:
"""Tests for nn.Parameter handling in CUDA graph capture."""
def test_parameters_excluded_from_graph_inputs(self, cuda_graph_test_env):
"""Test that nn.Parameters are not included in graph input tensors."""
tmp_path, test_mgr = cuda_graph_test_env
cache_dir = os.path.join(str(tmp_path), "cache_params")
os.makedirs(cache_dir, exist_ok=True)
compile_config = _create_compile_config(cache_dir, CudaGraphMode.FULL)
with patch("magi_compiler._api.get_compile_config", return_value=compile_config), patch(
"torch.distributed.get_rank", return_value=0
):
@magi_compile(dynamic_arg_dims={"x": 0})
class ParamModel(nn.Module):
def __init__(self, model_config: Optional[dict] = None):
super().__init__()
# Multiple nn.Parameters (simulating model weights)
self.weight1 = nn.Parameter(torch.randn(10, 20).cuda())
self.bias1 = nn.Parameter(torch.randn(20).cuda())
self.weight2 = nn.Parameter(torch.randn(20, 5).cuda())
self.bias2 = nn.Parameter(torch.randn(5).cuda())
self.nested_params = nn.ParameterList(
[nn.Parameter(torch.randn(3, 5).cuda()), nn.Parameter(torch.randn(4, 4).cuda())]
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = torch.matmul(x, self.weight1) + self.bias1
x = torch.relu(x)
x = torch.matmul(x, self.weight2) + self.bias2
x = x + torch.matmul(torch.ones_like(x[:, :3]), self.nested_params[0])
return x
model = ParamModel(model_config=None).cuda()
with torch.no_grad():
x1 = torch.randn(4, 10).cuda()
output1 = model(x1)
assert output1.shape == (4, 5)
import magi_compiler.magi_backend.cuda_graph_mgr as cgm
active_mgr = cgm.cuda_graph_mgr()
# Only 1 tensor entry (the input x), not the parameters
assert active_mgr.tensor_entry_count == 1
# Verify the static entry contains only the input tensor
static_entry = next(iter(active_mgr.cache.values()))
assert len(static_entry.input_tensors) == 1
assert len(static_entry.output_tensors) == 1
assert isinstance(static_entry.input_tensors[0], torch.Tensor)
assert not isinstance(static_entry.input_tensors[0], nn.Parameter)
# Additional verification: ArgsUtils extracts only input tensor
from magi_compiler.magi_backend.cuda_graph_mgr import ArgsUtils
input_obj = {"args": (x1,), "kwargs": {}}
extracted_tensors, _, _ = ArgsUtils.recursive_extract_core(input_obj)
assert len(extracted_tensors) == 1
assert extracted_tensors[0].data_ptr() == x1.data_ptr()
if __name__ == "__main__":
pytest.main([__file__, "-v"])

Xet Storage Details

Size:
11.7 kB
·
Xet hash:
df3f9547714c886e8b27313f599753a61f478a07b9e112b70fa1123161fe7074

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.