download
raw
11.1 kB
# Copyright (c) 2025 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Type
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from magi_compiler import magi_compile, magi_register_custom_op
from magi_compiler.config import OffloadPolicy, get_compile_config
from tests.model_definition import MLPConfig, RMSNorm
class TransformerWrapper(nn.Module):
"""
A wrapper class simulating a Transformer Block.
Accepts mlp_cls to support injecting dynamically defined classes.
"""
def __init__(self, config: MLPConfig, mlp_cls: Type[nn.Module]):
super().__init__()
# Standard layer (should move to GPU)
self.attention_proj = nn.Linear(config.hidden_size, config.hidden_size, dtype=config.params_dtype)
# Compiled layer (should stay on CPU if offload is enabled)
self.mlp = mlp_cls(config)
def forward(self, x):
x = self.mlp(x)
x = my_attention(x, x, x)
x = self.attention_proj(x)
return x
@contextmanager
def set_cpu_offload(enable: bool, offload_policy: OffloadPolicy = OffloadPolicy.COST_EFFECTIVE):
"""
Context manager to temporarily override the cpu_offload setting in global config.
"""
config = get_compile_config()
original_value = config.offload_config.model_cpu_offload
config.offload_config.model_cpu_offload = enable
original_offload_policy = config.offload_config.offload_policy
config.offload_config.offload_policy = offload_policy
try:
yield
finally:
config.offload_config.model_cpu_offload = original_value
config.offload_config.offload_policy = original_offload_policy
def create_offload_mlp_class():
"""
Create MLP class at runtime so that @magi_compile decorator captures the *current* config state.
This is necessary because the decorator runs at class definition time.
By defining the class inside a function called within `set_cpu_offload(True)` context,
we ensure the decorator sees `model_cpu_offload=True`.
"""
@magi_compile(dynamic_arg_dims={"x": 0})
class OffloadMLP(torch.nn.Module):
config: MLPConfig
def __init__(self, config: MLPConfig):
super().__init__()
self.config = config
self.pre_norm = RMSNorm(config.hidden_size)
self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False, dtype=config.params_dtype)
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False, dtype=config.params_dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.pre_norm(x).to(torch.bfloat16)
x = self.up_proj(x).to(torch.float32)
x = F.silu(x).to(torch.bfloat16)
x = self.down_proj(x)
return x
return OffloadMLP
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support")
def test_cpu_offload_placement(device, mlp_config):
"""
Test that the decorated module stays on CPU when .cuda() is called on parent,
while other modules move correctly.
"""
# Use the context manager to enable CPU offload
with set_cpu_offload(True):
# 1. Initialize the parent model
OffloadMLP = create_offload_mlp_class()
model = TransformerWrapper(mlp_config, mlp_cls=OffloadMLP)
# Verify initial state (everything on CPU by default in PyTorch)
assert model.attention_proj.weight.device.type == "cpu"
assert model.mlp.up_proj.weight.device.type == "cpu"
# 2. Move the model to GPU
# This triggers the _apply hook in _magi_compile
model.cuda()
# 3. Verify devices
# The standard layer should be on GPU
assert model.attention_proj.weight.device.type == "cuda", "Standard layers should move to CUDA"
# The compiled/offloaded layer should stay on CPU
assert (
model.mlp.up_proj.weight.device.type == "cpu"
), "Compiled MLP layer should remain on CPU due to offload configuration"
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support")
def test_cpu_offload_manual_move(device, mlp_config):
"""
Test that the offload hook only blocks the move ONCE.
Subsequent calls to .to(device) on the specific module should allow movement.
"""
with set_cpu_offload(True):
OffloadMLP = create_offload_mlp_class()
model = TransformerWrapper(mlp_config, mlp_cls=OffloadMLP)
# 1. First move (Should trigger offload logic)
model.cuda()
assert model.mlp.up_proj.weight.device.type == "cpu"
assert model.attention_proj.weight.device.type == "cuda"
# 2. Check if the internal flag is set (optional debugging check)
# Note: This relies on the implementation detail _magi_offloaded_once
if hasattr(model.mlp, "_magi_offloaded_once"):
assert model.mlp._magi_offloaded_once is True
# 3. Second move (Should not take any effect)
model.mlp.to(device)
assert model.mlp.up_proj.weight.device.type == "cpu", "Subsequent .to() calls should not take effect"
@magi_register_custom_op("athena::my_attention", infer_output_meta_fn=["q"], is_subgraph_boundary=True)
def my_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
return q + k + v
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support")
def test_cpu_offload_inference(device, mlp_config):
"""
Test that the offload hook only blocks the move ONCE.
Subsequent calls to .to(device) on the specific module should allow movement.
"""
test_shapes = [
(4, mlp_config.hidden_size), # Small batch
(8, mlp_config.hidden_size), # Medium batch
(16, mlp_config.hidden_size), # Large batch
# NOTE: compiler will specialize for single token, so we move it to the last
(1, mlp_config.hidden_size), # Single token
]
with set_cpu_offload(True):
OffloadMLP = create_offload_mlp_class()
model = TransformerWrapper(mlp_config, mlp_cls=OffloadMLP)
# 1. First move (Should trigger offload logic)
model.cuda()
assert model.mlp.up_proj.weight.device.type == "cpu"
assert model.attention_proj.weight.device.type == "cuda"
with torch.no_grad():
for num_tokens, hidden_size in test_shapes:
input_tensor = torch.randn(num_tokens, hidden_size, device=device, dtype=mlp_config.params_dtype)
output = model(input_tensor)
assert output.shape == (
num_tokens,
hidden_size,
), f"For input shape ({num_tokens}, {hidden_size}), output shape should be ({num_tokens}, {hidden_size}), but got {output.shape}"
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support")
def test_cpu_offload_heuristic(device, mlp_config):
"""
Test that the heuristic scheduler is working correctly.
"""
test_shapes = [
(4, mlp_config.hidden_size), # Small batch
(8, mlp_config.hidden_size), # Medium batch
(16, mlp_config.hidden_size), # Large batch
# NOTE: compiler will specialize for single token, so we move it to the last
(1, mlp_config.hidden_size), # Single token
]
with set_cpu_offload(True, OffloadPolicy.HEURISTIC):
OffloadMLP = create_offload_mlp_class()
model = TransformerWrapper(mlp_config, mlp_cls=OffloadMLP)
model.cuda()
assert model.mlp.up_proj.weight.device.type == "cpu"
assert model.attention_proj.weight.device.type == "cuda"
with torch.no_grad():
for num_tokens, hidden_size in test_shapes:
input_tensor = torch.randn(num_tokens, hidden_size, device=device, dtype=mlp_config.params_dtype)
output = model(input_tensor)
assert output.shape == (
num_tokens,
hidden_size,
), f"For input shape ({num_tokens}, {hidden_size}), output shape should be ({num_tokens}, {hidden_size}), but got {output.shape}"
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support")
def test_cpu_offload_different_decoration_styles(device, mlp_config):
"""
Test CPU offload with different decoration styles.
"""
with set_cpu_offload(True):
# 1. Class Decoration
@magi_compile(dynamic_arg_dims={"x": 0})
class OffloadMLPClass(torch.nn.Module):
def __init__(self, config: MLPConfig):
super().__init__()
self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False, dtype=config.params_dtype)
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False, dtype=config.params_dtype)
def forward(self, x):
return self.down_proj(torch.nn.functional.silu(self.up_proj(x)))
# Test class level
model_cls = TransformerWrapper(mlp_config, mlp_cls=OffloadMLPClass)
model_cls.cuda()
assert model_cls.mlp.up_proj.weight.device.type == "cpu"
# 2. Instance Decoration (Should be skipped/not offloaded now)
# class PlainMLP(torch.nn.Module):
# def __init__(self, config: MLPConfig):
# super().__init__()
# self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False, dtype=config.params_dtype)
# self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False, dtype=config.params_dtype)
#
# def forward(self, x):
# return self.down_proj(torch.nn.functional.silu(self.up_proj(x)))
#
# model_inst = TransformerWrapper(mlp_config, mlp_cls=PlainMLP)
# magi_compile(model_inst.mlp, dynamic_arg_dims={"x": 0})
# model_inst.cuda()
# assert model_inst.mlp.up_proj.weight.device.type == "cpu"
#
# # 3. Method Decoration (Should be skipped/not offloaded now)
# model_mtd = TransformerWrapper(mlp_config, mlp_cls=PlainMLP)
# magi_compile(model_mtd.mlp.forward, dynamic_arg_dims={"x": 0})
# model_mtd.cuda()
# assert model_mtd.mlp.up_proj.weight.device.type == "cpu"
if __name__ == "__main__":
pytest.main([__file__, "-v"])

Xet Storage Details

Size:
11.1 kB
·
Xet hash:
2b0a72bf76e8fff6d9f26ec6203b45aae283e07c3644d1327127d09416841250

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.