# Copyright (c) 2025 SandAI. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from contextlib import contextmanager from typing import Type import pytest import torch import torch.nn as nn import torch.nn.functional as F from magi_compiler import magi_compile from magi_compiler.config import OffloadPolicy, get_compile_config from .model_definition import MLPConfig, RMSNorm class TransformerWrapper(nn.Module): """ A wrapper class simulating a Transformer Block. Accepts mlp_cls to support injecting dynamically defined classes. """ def __init__(self, config: MLPConfig, mlp_cls: Type[nn.Module]): super().__init__() # Standard layer (should move to GPU) self.attention_proj = nn.Linear(config.hidden_size, config.hidden_size, dtype=config.params_dtype) # Compiled layer (should stay on CPU if offload is enabled) self.mlp = mlp_cls(config) def forward(self, x): x = self.mlp(x) x = my_attention(x, x, x) x = self.attention_proj(x) return x @contextmanager def set_cpu_offload(enable: bool, offload_policy: OffloadPolicy = OffloadPolicy.COST_EFFECTIVE): """ Context manager to temporarily override the cpu_offload setting in global config. """ config = get_compile_config() original_value = config.offload_config.model_cpu_offload config.offload_config.model_cpu_offload = enable original_offload_policy = config.offload_config.offload_policy config.offload_config.offload_policy = offload_policy try: yield finally: config.offload_config.model_cpu_offload = original_value config.offload_config.offload_policy = original_offload_policy def create_offload_mlp_class(): """ Create MLP class at runtime so that @magi_compile decorator captures the *current* config state. This is necessary because the decorator runs at class definition time. By defining the class inside a function called within `set_cpu_offload(True)` context, we ensure the decorator sees `model_cpu_offload=True`. """ @magi_compile(dynamic_arg_dims={"x": 0}) class OffloadMLP(torch.nn.Module): config: MLPConfig def __init__(self, config: MLPConfig): super().__init__() self.config = config self.pre_norm = RMSNorm(config.hidden_size) self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False, dtype=config.params_dtype) self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False, dtype=config.params_dtype) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.pre_norm(x).to(torch.bfloat16) x = self.up_proj(x).to(torch.float32) x = F.silu(x).to(torch.bfloat16) x = self.down_proj(x) return x return OffloadMLP @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support") def test_cpu_offload_placement(device, mlp_config): """ Test that the decorated module stays on CPU when .cuda() is called on parent, while other modules move correctly. """ # Use the context manager to enable CPU offload with set_cpu_offload(True): # 1. Initialize the parent model OffloadMLP = create_offload_mlp_class() model = TransformerWrapper(mlp_config, mlp_cls=OffloadMLP) # Verify initial state (everything on CPU by default in PyTorch) assert model.attention_proj.weight.device.type == "cpu" assert model.mlp.up_proj.weight.device.type == "cpu" # 2. Move the model to GPU # This triggers the _apply hook in _magi_compile model.cuda() # 3. Verify devices # The standard layer should be on GPU assert model.attention_proj.weight.device.type == "cuda", "Standard layers should move to CUDA" # The compiled/offloaded layer should stay on CPU assert ( model.mlp.up_proj.weight.device.type == "cpu" ), "Compiled MLP layer should remain on CPU due to offload configuration" @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support") def test_cpu_offload_manual_move(device, mlp_config): """ Test that the offload hook only blocks the move ONCE. Subsequent calls to .to(device) on the specific module should allow movement. """ with set_cpu_offload(True): OffloadMLP = create_offload_mlp_class() model = TransformerWrapper(mlp_config, mlp_cls=OffloadMLP) # 1. First move (Should trigger offload logic) model.cuda() assert model.mlp.up_proj.weight.device.type == "cpu" assert model.attention_proj.weight.device.type == "cuda" # 2. Check if the internal flag is set (optional debugging check) # Note: This relies on the implementation detail _magi_offloaded_once if hasattr(model.mlp, "_magi_offloaded_once"): assert model.mlp._magi_offloaded_once is True # 3. Second move (Should bypass hook and actually move to GPU) # Manually force the submodule to GPU model.mlp.to(device) assert model.mlp.up_proj.weight.device.type == "cuda", "Subsequent .to() calls should allow moving the module to GPU" @torch.library.custom_op("athena::my_attention", mutates_args=()) def my_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: return q + k + v @my_attention.register_fake def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: return torch.empty_like(q) @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support") def test_cpu_offload_inference(device, mlp_config): """ Test that the offload hook only blocks the move ONCE. Subsequent calls to .to(device) on the specific module should allow movement. """ test_shapes = [ (32, mlp_config.hidden_size), # Small batch (128, mlp_config.hidden_size), # Medium batch (512, mlp_config.hidden_size), # Large batch # NOTE: compiler will specialize for single token, so we move it to the last (1, mlp_config.hidden_size), # Single token ] with set_cpu_offload(True): get_compile_config().splitting_ops.extend(["athena::my_attention"]) OffloadMLP = create_offload_mlp_class() model = TransformerWrapper(mlp_config, mlp_cls=OffloadMLP) # 1. First move (Should trigger offload logic) model.cuda() assert model.mlp.up_proj.weight.device.type == "cpu" assert model.attention_proj.weight.device.type == "cuda" with torch.no_grad(): for num_tokens, hidden_size in test_shapes: input_tensor = torch.randn(num_tokens, hidden_size, device=device, dtype=mlp_config.params_dtype) output = model(input_tensor) assert output.shape == ( num_tokens, hidden_size, ), f"For input shape ({num_tokens}, {hidden_size}), output shape should be ({num_tokens}, {hidden_size}), but got {output.shape}" @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support") def test_cpu_offload_heuristic(device, mlp_config): """ Test that the heuristic scheduler is working correctly. """ test_shapes = [ (32, mlp_config.hidden_size), # Small batch (128, mlp_config.hidden_size), # Medium batch (512, mlp_config.hidden_size), # Large batch # NOTE: compiler will specialize for single token, so we move it to the last (1, mlp_config.hidden_size), # Single token ] with set_cpu_offload(True, OffloadPolicy.HEURISTIC): get_compile_config().splitting_ops.extend(["athena::my_attention"]) OffloadMLP = create_offload_mlp_class() model = TransformerWrapper(mlp_config, mlp_cls=OffloadMLP) model.cuda() assert model.mlp.up_proj.weight.device.type == "cpu" assert model.attention_proj.weight.device.type == "cuda" with torch.no_grad(): for num_tokens, hidden_size in test_shapes: input_tensor = torch.randn(num_tokens, hidden_size, device=device, dtype=mlp_config.params_dtype) output = model(input_tensor) assert output.shape == ( num_tokens, hidden_size, ), f"For input shape ({num_tokens}, {hidden_size}), output shape should be ({num_tokens}, {hidden_size}), but got {output.shape}" if __name__ == "__main__": pytest.main([__file__, "-v"])