File size: 11,983 Bytes
e6066e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
# Copyright (c) 2025 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Tests for CUDA Graph compilation modes.

This module tests:
- FULL mode: entire forward pass captured as a single graph
- PIECEWISE mode: forward pass split at custom ops into multiple graphs
- Parameter handling: nn.Parameter should be ignored in graph capture
"""

import os
from typing import Optional
from unittest.mock import patch

import pytest
import torch
import torch.nn as nn
from magi_compiler.api import magi_compile
from magi_compiler.config import CompileConfig, CompileMode, CudaGraphMode
from magi_compiler.cuda_graph_mgr import CudaGraphMgr
from torch.testing import assert_close

pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")


def _register_custom_split_op():
    """Register a custom split op for piecewise graph testing."""
    try:
        torch.ops.athena.my_split_op
        return
    except AttributeError:
        pass

    @torch.library.custom_op("athena::my_split_op", mutates_args=())
    def my_split_op(x: torch.Tensor) -> torch.Tensor:
        return x.clone()

    @my_split_op.register_fake
    def _(x: torch.Tensor) -> torch.Tensor:
        return torch.empty_like(x)


@pytest.fixture(autouse=True)
def cuda_graph_test_env(tmp_path):
    """Fixture to set up a clean CUDA graph test environment."""
    _register_custom_split_op()

    # Create an isolated CudaGraphMgr instance for testing
    test_mgr = CudaGraphMgr()
    test_mgr.cache = dict()

    with patch("magi_compiler.cuda_graph_mgr.cuda_graph_mgr", return_value=test_mgr):
        yield tmp_path, test_mgr


def _create_compile_config(cache_dir: str, cudagraph_mode: CudaGraphMode, splitting_ops: list = None) -> CompileConfig:
    """Create a compile configuration for CUDA graph testing."""
    return CompileConfig(
        compile_mode=CompileMode.MAGI_COMPILE,
        backend="inductor",
        cudagraph_mode=cudagraph_mode,
        cudagraph_copy_inputs=True,
        splitting_ops=splitting_ops or [],
        cache_root_dir=cache_dir,
        dynamic_sources="",
        traced_files=set(),
    )


class TestCudaGraphFullMode:
    """Tests for CudaGraphMode.FULL - entire forward as single graph."""

    def test_full_mode_basic(self, cuda_graph_test_env):
        """Test basic FULL mode functionality with graph reuse."""
        tmp_path, test_mgr = cuda_graph_test_env
        cache_dir = os.path.join(str(tmp_path), "cache_full")
        os.makedirs(cache_dir, exist_ok=True)

        compile_config = _create_compile_config(cache_dir, CudaGraphMode.FULL)

        with patch("magi_compiler.api.get_compile_config", return_value=compile_config), patch(
            "torch.distributed.get_rank", return_value=0
        ):

            @magi_compile(dynamic_arg_dims={"x": 0})
            class FullModeModel(nn.Module):
                def __init__(self, model_config: Optional[dict] = None):
                    super().__init__()
                    self.linear1 = nn.Linear(10, 20).cuda()
                    self.linear2 = nn.Linear(20, 5).cuda()

                def forward(self, x: torch.Tensor) -> torch.Tensor:
                    x = torch.relu(self.linear1(x))
                    return self.linear2(x)

            model = FullModeModel(model_config=None).cuda()

            with torch.no_grad():
                # Prepare test inputs
                x1 = torch.randn(4, 10).cuda()
                x2 = x1.clone()
                x3 = torch.randn(2, 10).cuda()
                x4 = torch.randn(6, 10).cuda()
                x5 = x1.clone()

                import magi_compiler.cuda_graph_mgr as cgm

                active_mgr = cgm.cuda_graph_mgr()

                # First run: capture graph
                output1 = model(x1)
                assert output1.shape == (4, 5)
                assert active_mgr.graph_count == 1
                assert active_mgr.tensor_entry_count == 1

                # Same shape input: reuse tensor and graph
                output2 = model(x2)
                assert_close(output1, output2, rtol=1e-4, atol=1e-4)
                assert active_mgr.tensor_entry_count == 1
                assert active_mgr.graph_count == 1

                # Smaller batch: reuse tensor, new graph
                output3 = model(x3)
                assert output3.shape == (2, 5)
                assert active_mgr.tensor_entry_count == 1
                assert active_mgr.graph_count == 2

                # Larger batch: expand tensor, invalidate previous graphs
                output4 = model(x4)
                assert output4.shape == (6, 5)
                assert active_mgr.tensor_entry_count == 1
                assert active_mgr.graph_count == 1

                # Return to original batch size: recapture graph
                output5 = model(x5)
                assert_close(output1, output5, rtol=1e-4, atol=1e-4)
                assert active_mgr.tensor_entry_count == 1
                assert active_mgr.graph_count == 2


class TestCudaGraphPiecewiseMode:
    """Tests for CudaGraphMode.PIECEWISE - split at custom ops."""

    def test_piecewise_mode_with_split_op(self, cuda_graph_test_env):
        """Test PIECEWISE mode with custom splitting ops."""
        tmp_path, test_mgr = cuda_graph_test_env
        cache_dir = os.path.join(str(tmp_path), "cache_piecewise")
        os.makedirs(cache_dir, exist_ok=True)

        compile_config = _create_compile_config(cache_dir, CudaGraphMode.PIECEWISE, splitting_ops=["athena::my_split_op"])

        with patch("magi_compiler.api.get_compile_config", return_value=compile_config), patch(
            "torch.distributed.get_rank", return_value=0
        ):

            @magi_compile(dynamic_arg_dims={"x": 0})
            class PiecewiseModel(nn.Module):
                def __init__(self, *, model_config):
                    super().__init__()
                    self.linear1 = nn.Linear(10, 20).cuda()
                    self.linear2 = nn.Linear(20, 5).cuda()

                def forward(self, x: torch.Tensor) -> torch.Tensor:
                    x = self.linear1(x)
                    x = torch.ops.athena.my_split_op(x)  # Split point
                    x = self.linear2(x)
                    return x

            model = PiecewiseModel(model_config=None).cuda()

            with torch.no_grad():
                x1 = torch.randn(4, 10).cuda()
                x2 = torch.randn(2, 10).cuda()
                x3 = torch.randn(6, 10).cuda()
                x4 = torch.randn(4, 10).cuda()
                x5 = x1.clone()

                import magi_compiler.cuda_graph_mgr as cgm

                active_mgr = cgm.cuda_graph_mgr()

                # First run: capture 2 sub-graphs (before and after split op)
                output1 = model(x1)
                assert output1.shape == (4, 5)
                assert active_mgr.tensor_entry_count == 2  # Two tensor entries for two sub-graphs
                assert active_mgr.graph_count == 2

                # Same input: reuse all
                output2 = model(x1)
                assert_close(output1, output2, rtol=1e-4, atol=1e-4)
                assert active_mgr.tensor_entry_count == 2
                assert active_mgr.graph_count == 2

                # Smaller batch: reuse tensors, new sub-graphs
                output3 = model(x2)
                assert output3.shape == (2, 5)
                assert active_mgr.tensor_entry_count == 2
                assert active_mgr.graph_count == 4

                # Larger batch: expand tensors, invalidate previous sub-graphs
                output4 = model(x3)
                assert output4.shape == (6, 5)
                assert active_mgr.tensor_entry_count == 2
                assert active_mgr.graph_count == 2

                # Return to batch=4: reuse tensors, recapture sub-graphs
                output5 = model(x4)
                assert output5.shape == (4, 5)
                assert active_mgr.tensor_entry_count == 2
                assert active_mgr.graph_count == 4

                # Same as first input: verify output consistency
                output6 = model(x5)
                assert_close(output1, output6, rtol=1e-4, atol=1e-4)
                assert active_mgr.tensor_entry_count == 2
                assert active_mgr.graph_count == 4


class TestCudaGraphParameterHandling:
    """Tests for nn.Parameter handling in CUDA graph capture."""

    def test_parameters_excluded_from_graph_inputs(self, cuda_graph_test_env):
        """Test that nn.Parameters are not included in graph input tensors."""
        tmp_path, test_mgr = cuda_graph_test_env
        cache_dir = os.path.join(str(tmp_path), "cache_params")
        os.makedirs(cache_dir, exist_ok=True)

        compile_config = _create_compile_config(cache_dir, CudaGraphMode.FULL)

        with patch("magi_compiler.api.get_compile_config", return_value=compile_config), patch(
            "torch.distributed.get_rank", return_value=0
        ):

            @magi_compile(dynamic_arg_dims={"x": 0})
            class ParamModel(nn.Module):
                def __init__(self, model_config: Optional[dict] = None):
                    super().__init__()
                    # Multiple nn.Parameters (simulating model weights)
                    self.weight1 = nn.Parameter(torch.randn(10, 20).cuda())
                    self.bias1 = nn.Parameter(torch.randn(20).cuda())
                    self.weight2 = nn.Parameter(torch.randn(20, 5).cuda())
                    self.bias2 = nn.Parameter(torch.randn(5).cuda())
                    self.nested_params = nn.ParameterList(
                        [nn.Parameter(torch.randn(3, 5).cuda()), nn.Parameter(torch.randn(4, 4).cuda())]
                    )

                def forward(self, x: torch.Tensor) -> torch.Tensor:
                    x = torch.matmul(x, self.weight1) + self.bias1
                    x = torch.relu(x)
                    x = torch.matmul(x, self.weight2) + self.bias2
                    x = x + torch.matmul(torch.ones_like(x[:, :3]), self.nested_params[0])
                    return x

            model = ParamModel(model_config=None).cuda()

            with torch.no_grad():
                x1 = torch.randn(4, 10).cuda()
                output1 = model(x1)
                assert output1.shape == (4, 5)

                import magi_compiler.cuda_graph_mgr as cgm

                active_mgr = cgm.cuda_graph_mgr()

                # Only 1 tensor entry (the input x), not the parameters
                assert active_mgr.tensor_entry_count == 1

                # Verify the static entry contains only the input tensor
                static_entry = next(iter(active_mgr.cache.values()))
                assert len(static_entry.input_tensors) == 1
                assert len(static_entry.output_tensors) == 1
                assert isinstance(static_entry.input_tensors[0], torch.Tensor)
                assert not isinstance(static_entry.input_tensors[0], nn.Parameter)

            # Additional verification: ArgsUtils extracts only input tensor
            from magi_compiler.cuda_graph_mgr import ArgsUtils

            input_obj = {"args": (x1,), "kwargs": {}}
            extracted_tensors, _, _ = ArgsUtils.recursive_extract_core(input_obj)
            assert len(extracted_tensors) == 1
            assert extracted_tensors[0].data_ptr() == x1.data_ptr()


if __name__ == "__main__":
    pytest.main([__file__, "-v"])