# Copyright (c) 2025 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Type
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from magi_compiler import magi_compile
from magi_compiler.config import OffloadPolicy, get_compile_config
from .model_definition import MLPConfig, RMSNorm
class TransformerWrapper(nn.Module):
    """
    Minimal stand-in for a Transformer block used by the CPU-offload tests.

    The MLP class is injected through ``mlp_cls`` instead of being hard-coded,
    so tests can pass in a class that was created (and decorated) at runtime
    under a specific compile configuration.

    Args:
        config: MLP configuration. ``hidden_size`` and ``params_dtype`` size and
            type the projection layer, and the whole object is forwarded to
            ``mlp_cls``.
        mlp_cls: Module class, instantiated as ``mlp_cls(config)``.
    """

    def __init__(self, config: MLPConfig, mlp_cls: Type[nn.Module]):
        super().__init__()
        width = config.hidden_size
        # Ordinary (non-compiled) layer: expected to follow the parent onto the GPU.
        self.attention_proj = nn.Linear(width, width, dtype=config.params_dtype)
        # Injected layer: when it is the offload-compiled MLP, it is expected to
        # remain on the CPU while model CPU offload is enabled.
        self.mlp = mlp_cls(config)

    def forward(self, x):
        """Apply the MLP, then the ``my_attention`` custom op, then the projection."""
        mlp_out = self.mlp(x)
        # Self-attention stand-in: the same tensor is fed as q, k and v.
        attn_out = my_attention(mlp_out, mlp_out, mlp_out)
        return self.attention_proj(attn_out)
@contextmanager
def set_cpu_offload(enable: bool, offload_policy: OffloadPolicy = OffloadPolicy.COST_EFFECTIVE):
    """
    Temporarily override the CPU-offload settings of the global compile config.

    For the duration of the ``with`` block, ``offload_config.model_cpu_offload``
    is set to ``enable`` and ``offload_config.offload_policy`` is set to
    ``offload_policy``. Both fields are restored on exit, including when the
    block raises or when applying an override fails.

    Args:
        enable: Value for ``model_cpu_offload`` inside the block.
        offload_policy: Value for ``offload_policy`` inside the block.
    """
    config = get_compile_config()
    # Snapshot every field before mutating any of them. Previously
    # model_cpu_offload was changed outside the try, so a failure while
    # reading/setting offload_policy left the global config half-modified.
    original_value = config.offload_config.model_cpu_offload
    original_offload_policy = config.offload_config.offload_policy
    try:
        config.offload_config.model_cpu_offload = enable
        config.offload_config.offload_policy = offload_policy
        yield
    finally:
        config.offload_config.model_cpu_offload = original_value
        config.offload_config.offload_policy = original_offload_policy
def create_offload_mlp_class():
    """
    Build and return a freshly ``@magi_compile``-decorated MLP class.

    The decorator runs when the ``class`` statement executes. Defining the class
    inside this factory, rather than at module level, lets each test choose the
    compile config the decorator observes. Calling the factory inside
    ``set_cpu_offload(True)`` makes the decorator see ``model_cpu_offload=True``.

    Returns:
        A new ``OffloadMLP`` class (RMSNorm -> up-projection -> SiLU ->
        down-projection), decorated with ``dynamic_arg_dims={"x": 0}`` for
        its ``forward`` input ``x``.
    """

    @magi_compile(dynamic_arg_dims={"x": 0})
    class OffloadMLP(torch.nn.Module):
        config: MLPConfig

        def __init__(self, config: MLPConfig):
            super().__init__()
            d_model = config.hidden_size
            d_ff = config.intermediate_size
            weight_dtype = config.params_dtype
            self.config = config
            self.pre_norm = RMSNorm(d_model)
            self.up_proj = nn.Linear(d_model, d_ff, bias=False, dtype=weight_dtype)
            self.down_proj = nn.Linear(d_ff, d_model, bias=False, dtype=weight_dtype)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # NOTE(review): the explicit bf16/fp32 casts between stages look
            # deliberate (e.g. to put cast ops into the traced graph); confirm intent.
            normed = self.pre_norm(x).to(torch.bfloat16)
            hidden = self.up_proj(normed).to(torch.float32)
            activated = F.silu(hidden).to(torch.bfloat16)
            return self.down_proj(activated)

    return OffloadMLP
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support")
def test_cpu_offload_placement(device, mlp_config):
    """
    Check device placement after ``.cuda()`` on the parent model.

    Ordinary submodules must move to the GPU, while the ``@magi_compile``-decorated
    MLP created under CPU offload must stay on the CPU.
    """
    with set_cpu_offload(True):
        model = TransformerWrapper(mlp_config, mlp_cls=create_offload_mlp_class())

        # Freshly constructed PyTorch parameters live on the CPU.
        assert model.attention_proj.weight.device.type == "cpu"
        assert model.mlp.up_proj.weight.device.type == "cpu"

        # Moving the parent goes through the _apply hook installed by magi_compile.
        model.cuda()

        assert model.attention_proj.weight.device.type == "cuda", "Standard layers should move to CUDA"
        assert (
            model.mlp.up_proj.weight.device.type == "cpu"
        ), "Compiled MLP layer should remain on CPU due to offload configuration"
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support")
def test_cpu_offload_manual_move(device, mlp_config):
    """
    The offload hook must intercept only the first device move.

    After the parent's ``.cuda()`` leaves the compiled MLP on the CPU, an explicit
    ``.to(device)`` on that submodule must actually move its weights.
    """
    with set_cpu_offload(True):
        model = TransformerWrapper(mlp_config, mlp_cls=create_offload_mlp_class())

        # First move: the offload hook keeps the compiled MLP on the CPU.
        model.cuda()
        assert model.mlp.up_proj.weight.device.type == "cpu"
        assert model.attention_proj.weight.device.type == "cuda"

        # Implementation detail, checked only when the attribute exists:
        # the hook records that it has already fired.
        assert getattr(model.mlp, "_magi_offloaded_once", True) is True

        # Second, explicit move of the submodule must not be intercepted.
        model.mlp.to(device)
        assert model.mlp.up_proj.weight.device.type == "cuda", "Subsequent .to() calls should allow moving the module to GPU"
@torch.library.custom_op("athena::my_attention", mutates_args=())
def my_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
return q + k + v
@my_attention.register_fake
def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
return torch.empty_like(q)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support")
def test_cpu_offload_inference(device, mlp_config):
    """
    Run forward passes through a model whose compiled MLP is CPU-offloaded.

    Setup: ``model_cpu_offload`` is enabled with the default ``COST_EFFECTIVE``
    policy, and ``athena::my_attention`` is added to ``splitting_ops``.

    Expectations: the compiled MLP stays on the CPU after ``model.cuda()``, and
    each forward pass over the token counts below returns an output of shape
    ``(num_tokens, hidden_size)``. The global ``splitting_ops`` list is restored
    afterwards.
    """
    test_shapes = [
        (32, mlp_config.hidden_size),  # Small batch
        (128, mlp_config.hidden_size),  # Medium batch
        (512, mlp_config.hidden_size),  # Large batch
        # NOTE: compiler will specialize for single token, so we move it to the last
        (1, mlp_config.hidden_size),  # Single token
    ]
    compile_config = get_compile_config()
    # splitting_ops lives on the global compile config. Snapshot it so the
    # entry added below does not leak into later tests or accumulate
    # duplicates each time a test extends it.
    saved_splitting_ops = list(compile_config.splitting_ops)
    try:
        with set_cpu_offload(True):
            compile_config.splitting_ops.extend(["athena::my_attention"])
            OffloadMLP = create_offload_mlp_class()
            model = TransformerWrapper(mlp_config, mlp_cls=OffloadMLP)

            # 1. First move (should trigger offload logic)
            model.cuda()
            assert model.mlp.up_proj.weight.device.type == "cpu"
            assert model.attention_proj.weight.device.type == "cuda"

            with torch.no_grad():
                for num_tokens, hidden_size in test_shapes:
                    input_tensor = torch.randn(num_tokens, hidden_size, device=device, dtype=mlp_config.params_dtype)
                    output = model(input_tensor)
                    assert output.shape == (
                        num_tokens,
                        hidden_size,
                    ), f"For input shape ({num_tokens}, {hidden_size}), output shape should be ({num_tokens}, {hidden_size}), but got {output.shape}"
    finally:
        # Slice-assign so the same list object (which other code may hold) is restored in place.
        compile_config.splitting_ops[:] = saved_splitting_ops
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support")
def test_cpu_offload_heuristic(device, mlp_config):
    """
    Smoke-test CPU offload under ``OffloadPolicy.HEURISTIC``.

    Setup: ``model_cpu_offload`` is enabled with the heuristic policy, and
    ``athena::my_attention`` is added to ``splitting_ops``.

    Expectations: the compiled MLP stays on the CPU after ``model.cuda()``, and
    forward passes over several token counts return outputs of shape
    ``(num_tokens, hidden_size)``. Only placement and output shapes are
    checked; the scheduler's individual decisions are not inspected. The global
    ``splitting_ops`` list is restored afterwards.
    """
    test_shapes = [
        (32, mlp_config.hidden_size),  # Small batch
        (128, mlp_config.hidden_size),  # Medium batch
        (512, mlp_config.hidden_size),  # Large batch
        # NOTE: compiler will specialize for single token, so we move it to the last
        (1, mlp_config.hidden_size),  # Single token
    ]
    compile_config = get_compile_config()
    # splitting_ops lives on the global compile config. Snapshot it so the
    # entry added below does not leak into later tests or accumulate
    # duplicates each time a test extends it.
    saved_splitting_ops = list(compile_config.splitting_ops)
    try:
        with set_cpu_offload(True, OffloadPolicy.HEURISTIC):
            compile_config.splitting_ops.extend(["athena::my_attention"])
            OffloadMLP = create_offload_mlp_class()
            model = TransformerWrapper(mlp_config, mlp_cls=OffloadMLP)

            model.cuda()
            assert model.mlp.up_proj.weight.device.type == "cpu"
            assert model.attention_proj.weight.device.type == "cuda"

            with torch.no_grad():
                for num_tokens, hidden_size in test_shapes:
                    input_tensor = torch.randn(num_tokens, hidden_size, device=device, dtype=mlp_config.params_dtype)
                    output = model(input_tensor)
                    assert output.shape == (
                        num_tokens,
                        hidden_size,
                    ), f"For input shape ({num_tokens}, {hidden_size}), output shape should be ({num_tokens}, {hidden_size}), but got {output.shape}"
    finally:
        # Slice-assign so the same list object (which other code may hold) is restored in place.
        compile_config.splitting_ops[:] = saved_splitting_ops
if __name__ == "__main__":
    # Propagate pytest's exit status. Previously it was discarded, so running
    # this file as a script exited 0 even when tests failed.
    raise SystemExit(pytest.main([__file__, "-v"]))