# Copyright (c) 2026 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for ``GraphProfileWrapper`` helpers and compiled-vs-eager output parity."""

from unittest.mock import patch

import pytest
import torch
import torch.fx as fx

from magi_compiler.config import get_compile_config
from magi_compiler.tokenflow.graph_profile import GraphProfileWrapper
from magi_compiler.tokenflow.utils import CompiledTransformerModel, ModelConfig
from magi_compiler.utils import envs


@pytest.fixture(scope="function")
def simple_graph_profile_wrapper() -> GraphProfileWrapper:
    """Build a ``GraphProfileWrapper`` around an FX-traced Linear(128, 128) -> ReLU module."""

    class SimpleModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(128, 128)

        def forward(self, x):
            y = self.linear(x)
            z = torch.relu(y)
            return z

    model = SimpleModel()
    graph_module = fx.symbolic_trace(model)
    wrapper = GraphProfileWrapper(graph_module)
    return wrapper


def test_resolve_symint_expression(simple_graph_profile_wrapper) -> None:
    """Check ``_resolve_symint_expression`` on plain ints and symbolic expressions.

    Plain ints must pass through unchanged. For symbolic values the test expects
    every symbol in the expression (``s0``, ``s1``, ...) to be replaced by
    ``seq_len`` before evaluation.
    """
    wrapper = simple_graph_profile_wrapper
    seq_len = 64

    # Concrete sizes are returned as-is.
    assert wrapper._resolve_symint_expression(128, seq_len) == 128

    class FakeSymInt:
        """Stand-in for ``torch.SymInt`` whose ``str()`` is a chosen expression."""

        def __init__(self, expr_str):
            self.expr_str = expr_str

        def __str__(self):
            return self.expr_str

    # Patch torch.SymInt so FakeSymInt instances are treated as symbolic sizes
    # (presumably the wrapper detects them via torch.SymInt — confirm in graph_profile).
    with patch("torch.SymInt", FakeSymInt):
        sym_simple = torch.SymInt("s0")
        res_simple = wrapper._resolve_symint_expression(sym_simple, seq_len)
        assert res_simple == 64

        sym_complex = torch.SymInt("s0 * 2 + 10")
        res_complex = wrapper._resolve_symint_expression(sym_complex, seq_len)
        assert res_complex == 138

        # Multiple distinct symbols are all expected to resolve to seq_len.
        sym_multi = torch.SymInt("s0 + s1")
        res_multi = wrapper._resolve_symint_expression(sym_multi, seq_len)
        assert res_multi == 128


def test_generate_real_tensor(simple_graph_profile_wrapper) -> None:
    """Check ``_generate_real_tensor`` materializes a concrete, non-zero tensor.

    The symbolic leading dim (``s0 * 64``) must resolve against ``seq_len``, and
    the requested stride and dtype must be honored.
    """
    seq_len = 1

    class FakeSymInt:
        """Stand-in for ``torch.SymInt`` representing ``s0 * 64``."""

        def __init__(self, *args, **kwargs):
            pass

        def __str__(self):
            return "s0 * 64"

        def __int__(self):
            return seq_len * 64

    with patch("torch.SymInt", FakeSymInt):
        wrapper = simple_graph_profile_wrapper
        sym_dim = torch.SymInt()
        shape = (sym_dim, 128)
        stride = (128, 1)
        dtype = torch.float32
        device = torch.device("cpu")

        tensor = wrapper._generate_real_tensor(shape, stride, dtype, device, seq_len)

        assert tensor.shape == (seq_len * 64, 128)
        assert tensor.stride() == (128, 1)
        assert tensor.dtype == dtype
        # The generated data must not be all zeros.
        assert not torch.allclose(tensor, torch.zeros_like(tensor))


def test_e2e_correctness(monkeypatch) -> None:
    """Compiled model output must match the eager ``model.mod`` output across sequence lengths.

    Global settings touched here (the profile env flag and the shared
    ``splitting_ops`` list) are restored afterwards so other tests are not affected.
    """
    # monkeypatch restores the previous value automatically at teardown.
    monkeypatch.setattr(envs, "MAGI_ENABLE_PROFILE", True)
    # Debugging aid (not enabled by default): envs.MAGI_ENABLE_FX_GRAPH_VIZ.

    splitting_ops = get_compile_config().splitting_ops
    saved_splitting_ops = list(splitting_ops)
    # Avoid piling up duplicate entries when the test runs repeatedly in one process.
    if "athena::my_attention" not in splitting_ops:
        splitting_ops.append("athena::my_attention")

    try:
        performer_config = ModelConfig(
            hidden_size=4096,
            num_layers=1,
            num_heads_q=32,
            num_heads_kv=8,
            head_dim=128,
            intermediate_size=16384,
            activation_type="gelu",
        )

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = CompiledTransformerModel(performer_config).to(device).to(performer_config.params_dtype)
        uncompiled_model = model.mod

        # Mix of power-of-two and irregular lengths, down to a single token.
        test_seq_lens = [4096, 1014, 512, 101, 64, 7, 1]
        for seq_len in test_seq_lens:
            x = torch.randn(seq_len, performer_config.hidden_size, device=device, dtype=performer_config.params_dtype)
            with torch.no_grad():
                output = model(x)
                uncompiled_output = uncompiled_model(x)
            # The message expression is only evaluated when the assertion fails.
            assert torch.allclose(output, uncompiled_output, atol=1e-3), (
                f"compiled/eager mismatch at seq_len={seq_len}: "
                f"max abs diff {(output - uncompiled_output).abs().max().item():.3e}"
            )
    finally:
        # Restore in place: other code may hold a reference to this same list.
        splitting_ops[:] = saved_splitting_ops


if __name__ == "__main__":
    pytest.main(["-v", __file__])