Buckets:
| # Copyright (c) 2025 SandAI. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from contextlib import contextmanager | |
| from typing import Type | |
| import pytest | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from magi_compiler import magi_compile, magi_register_custom_op | |
| from magi_compiler.config import OffloadPolicy, get_compile_config | |
| from tests.model_definition import MLPConfig, RMSNorm | |
class TransformerWrapper(nn.Module):
    """
    Minimal stand-in for a Transformer block used by the offload tests.

    The MLP implementation is injected through ``mlp_cls`` so that tests can
    pass in classes that were defined (and decorated) at runtime.
    """

    def __init__(self, config: MLPConfig, mlp_cls: Type[nn.Module]):
        super().__init__()
        width = config.hidden_size
        # Plain projection layer: the tests expect it to follow the parent onto the GPU.
        self.attention_proj = nn.Linear(width, width, dtype=config.params_dtype)
        # Injected MLP: the tests expect it to stay on the CPU when offload is enabled.
        self.mlp = mlp_cls(config)

    def forward(self, x):
        """Run MLP -> ``my_attention`` (same tensor as q, k and v) -> projection."""
        mlp_out = self.mlp(x)
        attn_out = my_attention(mlp_out, mlp_out, mlp_out)
        return self.attention_proj(attn_out)
@contextmanager
def set_cpu_offload(enable: bool, offload_policy: OffloadPolicy = OffloadPolicy.COST_EFFECTIVE):
    """
    Context manager to temporarily override the CPU-offload settings in the global config.

    Args:
        enable: Value assigned to ``offload_config.model_cpu_offload`` inside the ``with`` block.
        offload_policy: Value assigned to ``offload_config.offload_policy`` inside the ``with`` block.

    Yields:
        None. Both settings are restored to their previous values on exit,
        including when the body raises.
    """
    config = get_compile_config()
    # Snapshot both originals before mutating anything, so a failure while
    # applying the overrides can never leave the global config half-modified.
    original_value = config.offload_config.model_cpu_offload
    original_offload_policy = config.offload_config.offload_policy
    try:
        config.offload_config.model_cpu_offload = enable
        config.offload_config.offload_policy = offload_policy
        yield
    finally:
        config.offload_config.model_cpu_offload = original_value
        config.offload_config.offload_policy = original_offload_policy
def create_offload_mlp_class():
    """
    Create the MLP class at runtime so that the @magi_compile decorator captures the *current* config state.

    The decorator runs at class definition time. By defining the class inside a
    function that is called within a ``set_cpu_offload(True)`` context, the
    decorator sees ``model_cpu_offload=True``.

    Returns:
        The freshly defined (and decorated) ``OffloadMLP`` class.
    """

    # NOTE(review): decorator call form mirrors the ``magi_compile(..., dynamic_arg_dims={"x": 0})``
    # usage elsewhere in this file — confirm against the magi_compiler API.
    @magi_compile(dynamic_arg_dims={"x": 0})
    class OffloadMLP(torch.nn.Module):
        config: MLPConfig

        def __init__(self, config: MLPConfig):
            super().__init__()
            self.config = config
            self.pre_norm = RMSNorm(config.hidden_size)
            self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False, dtype=config.params_dtype)
            self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False, dtype=config.params_dtype)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Alternate bf16 (linear layers) and fp32 (activation input) as in the reference MLP.
            x = self.pre_norm(x).to(torch.bfloat16)
            x = self.up_proj(x).to(torch.float32)
            x = F.silu(x).to(torch.bfloat16)
            x = self.down_proj(x)
            return x

    return OffloadMLP
def test_cpu_offload_placement(device, mlp_config):
    """
    Check device placement after calling ``.cuda()`` on the parent module:
    the decorated MLP must remain on the CPU while the ordinary layer moves.
    """

    def weight_device(layer):
        # Always re-read the weight so we observe the state after any move.
        return layer.weight.device.type

    # CPU offload is only enabled for the duration of this block.
    with set_cpu_offload(True):
        # Step 1: build the parent model around a freshly defined MLP class.
        mlp_cls = create_offload_mlp_class()
        model = TransformerWrapper(mlp_config, mlp_cls=mlp_cls)

        # Freshly constructed PyTorch modules live on the CPU.
        assert weight_device(model.attention_proj) == "cpu"
        assert weight_device(model.mlp.up_proj) == "cpu"

        # Step 2: move the parent to the GPU (this goes through the _apply hook in _magi_compile).
        model.cuda()

        # Step 3: the ordinary layer moved, the compiled/offloaded one did not.
        assert weight_device(model.attention_proj) == "cuda", "Standard layers should move to CUDA"
        assert (
            weight_device(model.mlp.up_proj) == "cpu"
        ), "Compiled MLP layer should remain on CPU due to offload configuration"
def test_cpu_offload_manual_move(device, mlp_config):
    """
    Test that an offloaded module stays on CPU even when it is moved explicitly.

    After the parent's ``.cuda()`` call leaves the MLP on the CPU, a subsequent
    ``.to(device)`` issued directly on the MLP must not take effect either.
    """
    with set_cpu_offload(True):
        OffloadMLP = create_offload_mlp_class()
        model = TransformerWrapper(mlp_config, mlp_cls=OffloadMLP)

        # 1. First move (should trigger the offload logic)
        model.cuda()
        assert model.mlp.up_proj.weight.device.type == "cpu"
        assert model.attention_proj.weight.device.type == "cuda"

        # 2. Check the internal flag if it exists (optional debugging check).
        # Note: this relies on the implementation detail _magi_offloaded_once,
        # so it is skipped when the attribute is absent.
        if hasattr(model.mlp, "_magi_offloaded_once"):
            assert model.mlp._magi_offloaded_once is True

        # 3. Second move, directly on the offloaded module (should not take any effect)
        model.mlp.to(device)
        assert model.mlp.up_proj.weight.device.type == "cpu", "Subsequent .to() calls should not take effect"
# NOTE(review): `magi_register_custom_op` is imported by this file but never used;
# presumably this function was meant to be registered as a custom op — confirm.
def my_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Toy attention stand-in: return the element-wise sum of ``q``, ``k`` and ``v``."""
    qk_sum = q + k
    return qk_sum + v
def test_cpu_offload_inference(device, mlp_config):
    """
    Test forward passes through a model whose MLP is offloaded to CPU.

    For several batch sizes, the output of the wrapper must have the same
    shape as the input.
    """
    test_shapes = [
        (4, mlp_config.hidden_size),  # Small batch
        (8, mlp_config.hidden_size),  # Medium batch
        (16, mlp_config.hidden_size),  # Large batch
        # NOTE: compiler will specialize for single token, so we move it to the last
        (1, mlp_config.hidden_size),  # Single token
    ]

    with set_cpu_offload(True):
        OffloadMLP = create_offload_mlp_class()
        model = TransformerWrapper(mlp_config, mlp_cls=OffloadMLP)

        # Move the parent; the offload logic should keep the MLP on the CPU.
        model.cuda()
        assert model.mlp.up_proj.weight.device.type == "cpu"
        assert model.attention_proj.weight.device.type == "cuda"

        with torch.no_grad():
            for num_tokens, hidden_size in test_shapes:
                input_tensor = torch.randn(num_tokens, hidden_size, device=device, dtype=mlp_config.params_dtype)
                output = model(input_tensor)
                assert output.shape == (
                    num_tokens,
                    hidden_size,
                ), f"For input shape ({num_tokens}, {hidden_size}), output shape should be ({num_tokens}, {hidden_size}), but got {output.shape}"
def test_cpu_offload_heuristic(device, mlp_config):
    """
    Run the offloaded model under ``OffloadPolicy.HEURISTIC`` and check that
    every forward pass returns an output shaped like its input.
    """
    # Small, medium and large batches first; the single-token case goes last
    # because the compiler will specialize for a single token.
    token_counts = (4, 8, 16, 1)
    test_shapes = [(count, mlp_config.hidden_size) for count in token_counts]

    with set_cpu_offload(True, OffloadPolicy.HEURISTIC):
        mlp_cls = create_offload_mlp_class()
        model = TransformerWrapper(mlp_config, mlp_cls=mlp_cls)
        model.cuda()

        # Placement sanity check before running inference.
        assert model.mlp.up_proj.weight.device.type == "cpu"
        assert model.attention_proj.weight.device.type == "cuda"

        with torch.no_grad():
            for num_tokens, hidden_size in test_shapes:
                input_tensor = torch.randn(num_tokens, hidden_size, device=device, dtype=mlp_config.params_dtype)
                output = model(input_tensor)
                expected_shape = (num_tokens, hidden_size)
                assert (
                    output.shape == expected_shape
                ), f"For input shape ({num_tokens}, {hidden_size}), output shape should be ({num_tokens}, {hidden_size}), but got {output.shape}"
def test_cpu_offload_different_decoration_styles(device, mlp_config):
    """
    Test CPU offload with different decoration styles.

    Only class-level decoration is exercised; the instance- and method-level
    cases are kept below as commented-out reference code.
    """
    with set_cpu_offload(True):
        # 1. Class Decoration
        # NOTE(review): decorator call form mirrors the commented-out
        # ``magi_compile(..., dynamic_arg_dims={"x": 0})`` calls below — confirm against the API.
        @magi_compile(dynamic_arg_dims={"x": 0})
        class OffloadMLPClass(torch.nn.Module):
            def __init__(self, config: MLPConfig):
                super().__init__()
                self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False, dtype=config.params_dtype)
                self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False, dtype=config.params_dtype)

            def forward(self, x):
                return self.down_proj(torch.nn.functional.silu(self.up_proj(x)))

        # Test class level
        model_cls = TransformerWrapper(mlp_config, mlp_cls=OffloadMLPClass)
        model_cls.cuda()
        assert model_cls.mlp.up_proj.weight.device.type == "cpu"

        # 2. Instance Decoration (Should be skipped/not offloaded now)
        # class PlainMLP(torch.nn.Module):
        #     def __init__(self, config: MLPConfig):
        #         super().__init__()
        #         self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False, dtype=config.params_dtype)
        #         self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False, dtype=config.params_dtype)
        #
        #     def forward(self, x):
        #         return self.down_proj(torch.nn.functional.silu(self.up_proj(x)))
        #
        # model_inst = TransformerWrapper(mlp_config, mlp_cls=PlainMLP)
        # magi_compile(model_inst.mlp, dynamic_arg_dims={"x": 0})
        # model_inst.cuda()
        # assert model_inst.mlp.up_proj.weight.device.type == "cpu"
        #
        # # 3. Method Decoration (Should be skipped/not offloaded now)
        # model_mtd = TransformerWrapper(mlp_config, mlp_cls=PlainMLP)
        # magi_compile(model_mtd.mlp.forward, dynamic_arg_dims={"x": 0})
        # model_mtd.cuda()
        # assert model_mtd.mlp.up_proj.weight.device.type == "cpu"
if __name__ == "__main__":
    # Propagate pytest's exit status so that running this file directly
    # reports failures to the calling shell/CI instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
Xet Storage Details
- Size:
- 11.1 kB
- Xet hash:
- 2b0a72bf76e8fff6d9f26ec6203b45aae283e07c3644d1327127d09416841250
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.