# Path: daVinci-MagiHuman/pkgs/MagiCompiler/tests/tokenflow/test_graph_profile.py
# jiadisu — Switch back to Docker SDK with local pkgs (e6066e8)
# Copyright (c) 2026 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import patch
import pytest
import torch
import torch.fx as fx
from magi_compiler.config import get_compile_config
from magi_compiler.tokenflow.graph_profile import GraphProfileWrapper
from magi_compiler.tokenflow.utils import CompiledTransformerModel, ModelConfig
from magi_compiler.utils import envs
@pytest.fixture(scope="function")
def simple_graph_profile_wrapper() -> GraphProfileWrapper:
    """Return a ``GraphProfileWrapper`` around an FX-traced ``Linear(128, 128) -> relu`` module.

    The fixture is function-scoped, so every test receives a freshly traced graph and a new wrapper.
    """

    class SimpleModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(128, 128)

        def forward(self, x):
            return torch.relu(self.linear(x))

    return GraphProfileWrapper(fx.symbolic_trace(SimpleModel()))
def test_resolve_symint_expression(simple_graph_profile_wrapper):
    """Check ``_resolve_symint_expression`` on plain ints and on symbolic size expressions.

    A plain ``int`` must come back unchanged. For the symbolic cases, ``torch.SymInt`` is
    patched with a stand-in whose ``str()`` is the expression text. The expected values
    correspond to every symbol (``s0``, ``s1``) taking the value ``seq_len``.
    """
    seq_len = 64
    wrapper = simple_graph_profile_wrapper

    # Concrete sizes pass straight through.
    assert wrapper._resolve_symint_expression(128, seq_len) == 128

    class FakeSymInt:
        """Stand-in for ``torch.SymInt`` that only carries its expression text."""

        def __init__(self, expr_str):
            self.expr_str = expr_str

        def __str__(self):
            return self.expr_str

    # (expression text, value expected once every symbol is replaced by seq_len)
    cases = (
        ("s0", 64),
        ("s0 * 2 + 10", 138),
        ("s0 + s1", 128),
    )
    with patch("torch.SymInt", FakeSymInt):
        for expr_text, expected in cases:
            resolved = wrapper._resolve_symint_expression(torch.SymInt(expr_text), seq_len)
            assert resolved == expected
def test_generate_real_tensor(simple_graph_profile_wrapper):
    """Check that ``_generate_real_tensor`` materialises a tensor whose leading dim is symbolic.

    ``torch.SymInt`` is patched with a stand-in that prints as ``"s0 * 64"`` and converts to
    ``64 * seq_len``. With ``seq_len == 1`` the generated tensor must have shape ``(64, 128)``,
    keep the requested stride and dtype, and must not be all zeros.
    """
    seq_len = 1
    expected_rows = seq_len * 64

    class FakeSymInt:
        """Stand-in for ``torch.SymInt`` representing the expression ``s0 * 64``."""

        def __init__(self, *args, **kwargs):
            pass

        def __str__(self):
            return "s0 * 64"

        def __int__(self):
            return expected_rows

    requested_stride = (128, 1)
    requested_dtype = torch.float32

    with patch("torch.SymInt", FakeSymInt):
        generated = simple_graph_profile_wrapper._generate_real_tensor(
            (torch.SymInt(), 128),
            requested_stride,
            requested_dtype,
            torch.device("cpu"),
            seq_len,
        )

    assert generated.shape == (expected_rows, 128)
    assert generated.stride() == requested_stride
    assert generated.dtype == requested_dtype
    # The tensor must hold generated data, not an all-zero placeholder.
    assert not torch.allclose(generated, torch.zeros_like(generated))
def test_e2e_correctness():
    """Compiled ``CompiledTransformerModel`` output must match the uncompiled module.

    The test enables ``envs.MAGI_ENABLE_PROFILE`` and registers ``athena::my_attention`` as a
    splitting op in the shared compile config. Both settings are global, so the previous values
    are restored in ``finally`` to keep them from leaking into other tests in the same session.
    The op is only added if it is missing, so repeated runs do not accumulate duplicates.
    """
    splitting_op = "athena::my_attention"
    compile_config = get_compile_config()
    saved_enable_profile = envs.MAGI_ENABLE_PROFILE
    saved_splitting_ops = list(compile_config.splitting_ops)

    envs.MAGI_ENABLE_PROFILE = True
    # Optional debugging toggle, left disabled: envs.MAGI_ENABLE_FX_GRAPH_VIZ = True
    if splitting_op not in compile_config.splitting_ops:
        compile_config.splitting_ops.append(splitting_op)

    try:
        performer_config = ModelConfig(
            hidden_size=4096,
            num_layers=1,
            num_heads_q=32,
            num_heads_kv=8,
            head_dim=128,
            intermediate_size=16384,
            activation_type="gelu",
        )
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = CompiledTransformerModel(performer_config).to(device).to(performer_config.params_dtype)
        # NOTE(review): ``.mod`` is presumably the wrapped eager module sharing the same
        # parameters, used here as the reference implementation — confirm.
        uncompiled_model = model.mod

        # Mix of large, power-of-two and odd lengths to exercise dynamic sequence shapes.
        test_seq_lens = [4096, 1014, 512, 101, 64, 7, 1]
        for seq_len in test_seq_lens:
            x = torch.randn(seq_len, performer_config.hidden_size, device=device, dtype=performer_config.params_dtype)
            with torch.no_grad():
                output = model(x)
                uncompiled_output = uncompiled_model(x)
            # The message (and its diff computation) is only evaluated when the check fails.
            assert torch.allclose(output, uncompiled_output, atol=1e-3), (
                f"seq_len={seq_len}: compiled output deviates from uncompiled output "
                f"(max abs diff {(output - uncompiled_output).abs().max().item():.3e})"
            )
    finally:
        envs.MAGI_ENABLE_PROFILE = saved_enable_profile
        compile_config.splitting_ops[:] = saved_splitting_ops
if __name__ == "__main__":
    # pytest.main returns the session exit code; raise it so that running this file
    # directly reports test failures through the process exit status.
    raise SystemExit(pytest.main(["-v", __file__]))