File size: 9,177 Bytes
e6066e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
# Copyright (c) 2025 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import contextmanager
from typing import Type

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from magi_compiler import magi_compile
from magi_compiler.config import OffloadPolicy, get_compile_config

from .model_definition import MLPConfig, RMSNorm


class TransformerWrapper(nn.Module):
    """
    Minimal stand-in for a Transformer block used by the CPU-offload tests.

    It pairs a plain ``nn.Linear`` projection with an MLP whose class is
    supplied by the caller, so a class created at runtime can be injected.
    """

    def __init__(self, config: MLPConfig, mlp_cls: Type[nn.Module]):
        """
        Build the projection layer and the injected MLP.

        Args:
            config: Provides ``hidden_size`` and ``params_dtype`` for the
                projection and is forwarded unchanged to ``mlp_cls``.
            mlp_cls: Invoked as ``mlp_cls(config)`` to construct the MLP.
        """
        super().__init__()
        hidden = config.hidden_size

        # Ordinary layer: the tests expect it to follow the parent to the GPU.
        self.attention_proj = nn.Linear(hidden, hidden, dtype=config.params_dtype)

        # Injected layer: the tests expect it to remain on CPU when offload is enabled.
        self.mlp = mlp_cls(config)

    def forward(self, x):
        """Apply the MLP, then ``my_attention`` with q = k = v, then the projection."""
        hidden_states = self.mlp(x)
        attended = my_attention(hidden_states, hidden_states, hidden_states)
        return self.attention_proj(attended)


@contextmanager
def set_cpu_offload(enable: bool, offload_policy: OffloadPolicy = OffloadPolicy.COST_EFFECTIVE):
    """
    Context manager to temporarily override the CPU-offload settings in the global compile config.

    Args:
        enable: Value assigned to ``offload_config.model_cpu_offload`` inside the ``with`` block.
        offload_policy: Value assigned to ``offload_config.offload_policy`` inside the ``with`` block.

    Both fields are restored to the values they had on entry when the block exits,
    whether it exits normally or by raising.
    """
    config = get_compile_config()

    # Snapshot both fields before changing either one, so the restore below is
    # always well-defined.
    original_value = config.offload_config.model_cpu_offload
    original_offload_policy = config.offload_config.offload_policy
    try:
        # Apply the overrides inside the try: if the second assignment raises,
        # the first override is still rolled back by the finally clause.
        config.offload_config.model_cpu_offload = enable
        config.offload_config.offload_policy = offload_policy
        yield
    finally:
        config.offload_config.model_cpu_offload = original_value
        config.offload_config.offload_policy = original_offload_policy


def create_offload_mlp_class():
    """
    Build and return a fresh ``@magi_compile``-decorated MLP class.

    The decorator is applied when the class statement executes. Defining the class
    inside this factory lets the caller decide when that happens: calling the
    factory within a ``set_cpu_offload(True)`` block means the decorator runs
    while ``model_cpu_offload=True`` is set in the global config.
    """

    @magi_compile(dynamic_arg_dims={"x": 0})
    class OffloadMLP(torch.nn.Module):
        config: MLPConfig

        def __init__(self, config: MLPConfig):
            super().__init__()
            self.config = config
            hidden, inner = config.hidden_size, config.intermediate_size
            dtype = config.params_dtype
            # Layers are created in the same order as they are applied in forward().
            self.pre_norm = RMSNorm(hidden)
            self.up_proj = nn.Linear(hidden, inner, bias=False, dtype=dtype)
            self.down_proj = nn.Linear(inner, hidden, bias=False, dtype=dtype)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # norm -> bf16 -> up projection -> fp32 -> SiLU -> bf16 -> down projection
            normed = self.pre_norm(x).to(torch.bfloat16)
            widened = self.up_proj(normed).to(torch.float32)
            activated = F.silu(widened).to(torch.bfloat16)
            return self.down_proj(activated)

    return OffloadMLP


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support")
def test_cpu_offload_placement(device, mlp_config):
    """
    Calling ``.cuda()`` on the parent must move ordinary layers to the GPU
    while the decorated MLP stays on the CPU.
    """
    with set_cpu_offload(True):
        # Create the MLP class inside the context so it is decorated while offload is enabled.
        mlp_cls = create_offload_mlp_class()
        wrapper = TransformerWrapper(mlp_config, mlp_cls=mlp_cls)

        # Sanity check: freshly constructed parameters start out on the CPU.
        proj_device = wrapper.attention_proj.weight.device.type
        assert proj_device == "cpu"
        mlp_device = wrapper.mlp.up_proj.weight.device.type
        assert mlp_device == "cpu"

        # Move the whole wrapper; the offload logic is expected to intercept the MLP's move.
        wrapper.cuda()

        # The ordinary layer follows the parent to the GPU ...
        proj_device = wrapper.attention_proj.weight.device.type
        assert proj_device == "cuda", "Standard layers should move to CUDA"

        # ... while the compiled/offloaded MLP does not.
        mlp_device = wrapper.mlp.up_proj.weight.device.type
        assert mlp_device == "cpu", "Compiled MLP layer should remain on CPU due to offload configuration"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support")
def test_cpu_offload_manual_move(device, mlp_config):
    """
    The offload logic must block only the FIRST move of the compiled module.
    A later explicit ``.to(device)`` on that module should actually move it.
    """
    with set_cpu_offload(True):
        mlp_cls = create_offload_mlp_class()
        wrapper = TransformerWrapper(mlp_config, mlp_cls=mlp_cls)
        offloaded_mlp = wrapper.mlp

        # First move: the MLP is held back on the CPU, the projection goes to the GPU.
        wrapper.cuda()
        assert offloaded_mlp.up_proj.weight.device.type == "cpu"
        assert wrapper.attention_proj.weight.device.type == "cuda"

        # Optional check of an implementation detail: only verified when the
        # _magi_offloaded_once flag is present on the module.
        if hasattr(offloaded_mlp, "_magi_offloaded_once"):
            assert offloaded_mlp._magi_offloaded_once is True

        # Second move, requested directly on the submodule: this one must go through.
        offloaded_mlp.to(device)

        final_device = offloaded_mlp.up_proj.weight.device.type
        assert final_device == "cuda", "Subsequent .to() calls should allow moving the module to GPU"


@torch.library.custom_op("athena::my_attention", mutates_args=())
def my_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    return q + k + v


@my_attention.register_fake
def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Fake implementation of ``my_attention``: an uninitialized tensor created like ``q``."""
    fake_out = torch.empty_like(q)
    return fake_out


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support")
def test_cpu_offload_inference(device, mlp_config):
    """
    Test that a model whose compiled MLP is kept on CPU by offload can still run
    forward passes on CUDA inputs of several batch sizes, producing outputs with
    the same shape as the inputs.
    """

    test_shapes = [
        (32, mlp_config.hidden_size),  # Small batch
        (128, mlp_config.hidden_size),  # Medium batch
        (512, mlp_config.hidden_size),  # Large batch
        # NOTE: compiler will specialize for single token, so we move it to the last
        (1, mlp_config.hidden_size),  # Single token
    ]

    compile_config = get_compile_config()
    # Snapshot the global splitting ops so the entry added below does not leak
    # into (or accumulate across) other tests sharing the same compile config.
    original_splitting_ops = list(compile_config.splitting_ops)
    try:
        with set_cpu_offload(True):
            compile_config.splitting_ops.extend(["athena::my_attention"])

            OffloadMLP = create_offload_mlp_class()

            model = TransformerWrapper(mlp_config, mlp_cls=OffloadMLP)

            # 1. First move (Should trigger offload logic)
            model.cuda()
            assert model.mlp.up_proj.weight.device.type == "cpu"
            assert model.attention_proj.weight.device.type == "cuda"

            # 2. Run inference for every shape and check the output shape
            with torch.no_grad():
                for num_tokens, hidden_size in test_shapes:
                    input_tensor = torch.randn(num_tokens, hidden_size, device=device, dtype=mlp_config.params_dtype)
                    output = model(input_tensor)

                    assert output.shape == (
                        num_tokens,
                        hidden_size,
                    ), f"For input shape ({num_tokens}, {hidden_size}), output shape should be ({num_tokens}, {hidden_size}), but got {output.shape}"
    finally:
        # Restore in place so any existing reference to the list sees the original contents.
        compile_config.splitting_ops[:] = original_splitting_ops


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support")
def test_cpu_offload_heuristic(device, mlp_config):
    """
    Test that inference works with CPU offload enabled under ``OffloadPolicy.HEURISTIC``:
    the compiled MLP stays on CPU after ``.cuda()`` and forward passes on CUDA inputs
    of several batch sizes produce outputs with the same shape as the inputs.
    """
    test_shapes = [
        (32, mlp_config.hidden_size),  # Small batch
        (128, mlp_config.hidden_size),  # Medium batch
        (512, mlp_config.hidden_size),  # Large batch
        # NOTE: compiler will specialize for single token, so we move it to the last
        (1, mlp_config.hidden_size),  # Single token
    ]

    compile_config = get_compile_config()
    # Snapshot the global splitting ops so the entry added below does not leak
    # into (or accumulate across) other tests sharing the same compile config.
    original_splitting_ops = list(compile_config.splitting_ops)
    try:
        with set_cpu_offload(True, OffloadPolicy.HEURISTIC):
            compile_config.splitting_ops.extend(["athena::my_attention"])
            OffloadMLP = create_offload_mlp_class()
            model = TransformerWrapper(mlp_config, mlp_cls=OffloadMLP)
            model.cuda()
            assert model.mlp.up_proj.weight.device.type == "cpu"
            assert model.attention_proj.weight.device.type == "cuda"

            with torch.no_grad():
                for num_tokens, hidden_size in test_shapes:
                    input_tensor = torch.randn(num_tokens, hidden_size, device=device, dtype=mlp_config.params_dtype)
                    output = model(input_tensor)

                    assert output.shape == (
                        num_tokens,
                        hidden_size,
                    ), f"For input shape ({num_tokens}, {hidden_size}), output shape should be ({num_tokens}, {hidden_size}), but got {output.shape}"
    finally:
        # Restore in place so any existing reference to the list sees the original contents.
        compile_config.splitting_ops[:] = original_splitting_ops


# Allow running this test module directly as a script; pytest is invoked on
# this file only, with verbose output.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])