Spaces:
Runtime error
Runtime error
| # Copyright (c) 2025 SandAI. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from contextlib import contextmanager | |
| from typing import Type | |
| import pytest | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from magi_compiler import magi_compile | |
| from magi_compiler.config import OffloadPolicy, get_compile_config | |
| from .model_definition import MLPConfig, RMSNorm | |
class TransformerWrapper(nn.Module):
    """
    Minimal stand-in for a Transformer block used by the offload tests.

    The MLP implementation is injected through ``mlp_cls`` so that classes
    defined dynamically at runtime can be plugged in.
    """

    def __init__(self, config: MLPConfig, mlp_cls: Type[nn.Module]):
        super().__init__()
        hidden = config.hidden_size
        # Plain projection layer (expected to move to the GPU with the parent)
        self.attention_proj = nn.Linear(hidden, hidden, dtype=config.params_dtype)
        # Injected MLP (expected to stay on the CPU when offload is enabled)
        self.mlp = mlp_cls(config)

    def forward(self, x):
        mlp_out = self.mlp(x)
        # The MLP output is passed as query, key and value alike.
        attn_out = my_attention(mlp_out, mlp_out, mlp_out)
        return self.attention_proj(attn_out)
@contextmanager
def set_cpu_offload(enable: bool, offload_policy: OffloadPolicy = OffloadPolicy.COST_EFFECTIVE):
    """
    Context manager to temporarily override the CPU-offload settings in the global config.

    Args:
        enable: Value assigned to ``offload_config.model_cpu_offload`` while the
            ``with`` block is active.
        offload_policy: Value assigned to ``offload_config.offload_policy`` while
            the ``with`` block is active.

    Yields:
        None. Both settings are put back to their previous values on exit,
        including when the body of the ``with`` block raises.
    """
    config = get_compile_config()

    # Remember the current values before overriding them.
    original_value = config.offload_config.model_cpu_offload
    original_offload_policy = config.offload_config.offload_policy

    config.offload_config.model_cpu_offload = enable
    config.offload_config.offload_policy = offload_policy
    try:
        yield
    finally:
        config.offload_config.model_cpu_offload = original_value
        config.offload_config.offload_policy = original_offload_policy
def create_offload_mlp_class():
    """
    Create MLP class at runtime so that @magi_compile decorator captures the *current* config state.

    This is necessary because the decorator runs at class definition time.
    By defining the class inside a function called within `set_cpu_offload(True)` context,
    we ensure the decorator sees `model_cpu_offload=True`.

    Returns:
        The freshly defined (and decorated) ``OffloadMLP`` class.
    """

    # NOTE(review): the decorator is applied without arguments; confirm against the
    # magi_compile API whether explicit options (e.g. dynamic dims) are required here.
    @magi_compile
    class OffloadMLP(torch.nn.Module):
        config: MLPConfig

        def __init__(self, config: MLPConfig):
            super().__init__()
            self.config = config
            self.pre_norm = RMSNorm(config.hidden_size)
            self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False, dtype=config.params_dtype)
            self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False, dtype=config.params_dtype)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # The casts alternate between bfloat16 and float32 around the activation.
            x = self.pre_norm(x).to(torch.bfloat16)
            x = self.up_proj(x).to(torch.float32)
            x = F.silu(x).to(torch.bfloat16)
            x = self.down_proj(x)
            return x

    return OffloadMLP
def test_cpu_offload_placement(device, mlp_config):
    """
    Check device placement after calling .cuda() on the parent module:
    the decorated MLP must remain on the CPU while the other layers move.
    """
    # CPU offload is switched on only for the duration of this block
    with set_cpu_offload(True):
        # Step 1: build the parent model around a freshly defined MLP class
        mlp_cls = create_offload_mlp_class()
        wrapper = TransformerWrapper(mlp_config, mlp_cls=mlp_cls)

        # Freshly constructed PyTorch modules start out on the CPU
        assert wrapper.attention_proj.weight.device.type == "cpu"
        assert wrapper.mlp.up_proj.weight.device.type == "cpu"

        # Step 2: request the move to GPU
        # (this goes through the _apply hook in _magi_compile)
        wrapper.cuda()

        # Step 3: inspect where each weight ended up
        proj_device = wrapper.attention_proj.weight.device.type
        # The regular layer is expected on the GPU
        assert proj_device == "cuda", "Standard layers should move to CUDA"

        mlp_device = wrapper.mlp.up_proj.weight.device.type
        # The compiled/offloaded layer is expected to still be on the CPU
        assert mlp_device == "cpu", "Compiled MLP layer should remain on CPU due to offload configuration"
def test_cpu_offload_manual_move(device, mlp_config):
    """
    Check that the offload hook intercepts only the FIRST move.
    A later explicit .to(device) on the submodule itself must go through.
    """
    with set_cpu_offload(True):
        mlp_cls = create_offload_mlp_class()
        wrapper = TransformerWrapper(mlp_config, mlp_cls=mlp_cls)

        # Step 1: initial move, expected to be intercepted for the MLP
        wrapper.cuda()
        assert wrapper.mlp.up_proj.weight.device.type == "cpu"
        assert wrapper.attention_proj.weight.device.type == "cuda"

        # Step 2: optional debugging check of the internal flag
        # Note: _magi_offloaded_once is an implementation detail, so it is
        # only verified when the attribute actually exists.
        missing = object()
        offloaded_once = getattr(wrapper.mlp, "_magi_offloaded_once", missing)
        if offloaded_once is not missing:
            assert offloaded_once is True

        # Step 3: second move, expected to bypass the hook
        # Explicitly push just the submodule to the target device
        wrapper.mlp.to(device)
        mlp_device = wrapper.mlp.up_proj.weight.device.type
        assert mlp_device == "cuda", "Subsequent .to() calls should allow moving the module to GPU"
| def my_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: | |
| return q + k + v | |
| def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: | |
| return torch.empty_like(q) | |
def test_cpu_offload_inference(device, mlp_config):
    """
    Test that a model with an offloaded MLP can run forward passes and that
    the output shape matches the input shape for several batch sizes.
    """
    test_shapes = [
        (32, mlp_config.hidden_size),  # Small batch
        (128, mlp_config.hidden_size),  # Medium batch
        (512, mlp_config.hidden_size),  # Large batch
        # NOTE: compiler will specialize for single token, so we move it to the last
        (1, mlp_config.hidden_size),  # Single token
    ]
    with set_cpu_offload(True):
        # splitting_ops lives in the global config: snapshot it so the extra
        # entry added here does not leak into (or pile up across) other tests.
        splitting_ops = get_compile_config().splitting_ops
        saved_splitting_ops = list(splitting_ops)
        splitting_ops.extend(["athena::my_attention"])
        try:
            OffloadMLP = create_offload_mlp_class()
            model = TransformerWrapper(mlp_config, mlp_cls=OffloadMLP)

            # 1. First move (Should trigger offload logic)
            model.cuda()
            assert model.mlp.up_proj.weight.device.type == "cpu"
            assert model.attention_proj.weight.device.type == "cuda"

            # 2. Run inference for every shape and check the output shape
            with torch.no_grad():
                for num_tokens, hidden_size in test_shapes:
                    input_tensor = torch.randn(num_tokens, hidden_size, device=device, dtype=mlp_config.params_dtype)
                    output = model(input_tensor)
                    assert output.shape == (
                        num_tokens,
                        hidden_size,
                    ), f"For input shape ({num_tokens}, {hidden_size}), output shape should be ({num_tokens}, {hidden_size}), but got {output.shape}"
        finally:
            # In-place restore keeps the same list object referenced by the config.
            splitting_ops[:] = saved_splitting_ops
def test_cpu_offload_heuristic(device, mlp_config):
    """
    Test that inference works with the HEURISTIC offload policy: the MLP stays
    on the CPU after .cuda() and forward passes return the expected shapes.
    """
    test_shapes = [
        (32, mlp_config.hidden_size),  # Small batch
        (128, mlp_config.hidden_size),  # Medium batch
        (512, mlp_config.hidden_size),  # Large batch
        # NOTE: compiler will specialize for single token, so we move it to the last
        (1, mlp_config.hidden_size),  # Single token
    ]
    with set_cpu_offload(True, OffloadPolicy.HEURISTIC):
        # splitting_ops lives in the global config: snapshot it so the extra
        # entry added here does not leak into (or pile up across) other tests.
        splitting_ops = get_compile_config().splitting_ops
        saved_splitting_ops = list(splitting_ops)
        splitting_ops.extend(["athena::my_attention"])
        try:
            OffloadMLP = create_offload_mlp_class()
            model = TransformerWrapper(mlp_config, mlp_cls=OffloadMLP)

            model.cuda()
            assert model.mlp.up_proj.weight.device.type == "cpu"
            assert model.attention_proj.weight.device.type == "cuda"

            with torch.no_grad():
                for num_tokens, hidden_size in test_shapes:
                    input_tensor = torch.randn(num_tokens, hidden_size, device=device, dtype=mlp_config.params_dtype)
                    output = model(input_tensor)
                    assert output.shape == (
                        num_tokens,
                        hidden_size,
                    ), f"For input shape ({num_tokens}, {hidden_size}), output shape should be ({num_tokens}, {hidden_size}), but got {output.shape}"
        finally:
            # In-place restore keeps the same list object referenced by the config.
            splitting_ops[:] = saved_splitting_ops
if __name__ == "__main__":
    # pytest.main returns the run's exit code; propagate it so that running this
    # file directly reports failures to the shell instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))