Spaces:
Runtime error
Runtime error
| # Copyright (c) 2026 SandAI. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from unittest.mock import patch | |
| import pytest | |
| import torch | |
| import torch.fx as fx | |
| from magi_compiler.config import get_compile_config | |
| from magi_compiler.tokenflow.graph_profile import GraphProfileWrapper | |
| from magi_compiler.tokenflow.utils import CompiledTransformerModel, ModelConfig | |
| from magi_compiler.utils import envs | |
@pytest.fixture
def simple_graph_profile_wrapper() -> GraphProfileWrapper:
    """Build a ``GraphProfileWrapper`` around a tiny traced Linear(128, 128) + ReLU model.

    The tests in this module receive the wrapper by declaring a
    ``simple_graph_profile_wrapper`` parameter, which relies on pytest fixture
    injection -- hence the ``@pytest.fixture`` decorator.

    Returns:
        A ``GraphProfileWrapper`` constructed from the ``fx.symbolic_trace`` of the
        small model defined below.
    """

    class SimpleModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(128, 128)

        def forward(self, x):
            y = self.linear(x)
            z = torch.relu(y)
            return z

    model = SimpleModel()
    # Symbolically trace to obtain the fx.GraphModule that the wrapper operates on.
    graph_module = fx.symbolic_trace(model)
    wrapper = GraphProfileWrapper(graph_module)
    return wrapper
def test_resolve_symint_expression(simple_graph_profile_wrapper):
    """Verify ``_resolve_symint_expression`` for concrete and symbolic sizes.

    A plain ``int`` must come back unchanged. Symbolic sizes are simulated by a
    stand-in for ``torch.SymInt`` whose ``str()`` is the expression text; the
    expected values below correspond to evaluating each expression with every
    symbol (``s0``, ``s1``) set to the given sequence length.
    """
    wrapper = simple_graph_profile_wrapper
    length = 64

    # Concrete integers pass straight through.
    assert wrapper._resolve_symint_expression(128, length) == 128

    class StubSymInt:
        """Minimal stand-in for ``torch.SymInt``: only its string form matters."""

        def __init__(self, text):
            self.expr_str = text

        def __str__(self):
            return self.expr_str

    # (expression text, expected result with every symbol == length)
    cases = (
        ("s0", 64),
        ("s0 * 2 + 10", 138),
        ("s0 + s1", 128),
    )

    with patch("torch.SymInt", StubSymInt):
        for expression, expected in cases:
            symbolic_size = torch.SymInt(expression)
            resolved = wrapper._resolve_symint_expression(symbolic_size, length)
            assert resolved == expected
def test_generate_real_tensor(simple_graph_profile_wrapper):
    """Verify ``_generate_real_tensor`` materializes a tensor for a symbolic shape.

    The leading dimension is a stand-in ``torch.SymInt`` whose ``str()`` is
    ``"s0 * 64"`` and whose ``int()`` is ``seq_len * 64``. The generated tensor
    must have the resolved shape, the requested stride and dtype, and must not
    be all zeros.
    """
    seq_len = 1
    expected_rows = seq_len * 64

    class StubSymInt:
        """Stand-in for ``torch.SymInt`` that ignores constructor arguments."""

        def __init__(self, *args, **kwargs):
            pass

        def __str__(self):
            return "s0 * 64"

        def __int__(self):
            return expected_rows

    with patch("torch.SymInt", StubSymInt):
        target_dtype = torch.float32
        cpu = torch.device("cpu")
        requested_shape = (torch.SymInt(), 128)
        requested_stride = (128, 1)

        generated = simple_graph_profile_wrapper._generate_real_tensor(
            requested_shape, requested_stride, target_dtype, cpu, seq_len
        )

        assert generated.shape == (expected_rows, 128)
        assert generated.stride() == requested_stride
        assert generated.dtype == target_dtype
        # The tensor should hold real (non-zero) data, not a zero placeholder.
        assert not torch.allclose(generated, torch.zeros_like(generated))
def test_e2e_correctness():
    """Check that a compiled one-layer transformer matches its uncompiled module.

    For each sequence length in ``test_seq_lens``, random input is fed both to
    ``CompiledTransformerModel`` and to ``model.mod`` (used here as the
    uncompiled reference). The test asserts the outputs are
    ``torch.allclose`` with ``atol=1e-3``.

    The test temporarily:
      * sets ``envs.MAGI_ENABLE_PROFILE`` to ``True``;
      * registers ``athena::my_attention`` in the ``splitting_ops`` of the config
        returned by ``get_compile_config()`` (presumably process-wide -- confirm).
    Both changes are undone on exit so they do not leak into other tests.
    """
    splitting_op = "athena::my_attention"
    compile_config = get_compile_config()
    # Only register the op if it is absent, so repeated runs in one process do
    # not pile up duplicate entries; remember whether we added it so we can undo it.
    added_splitting_op = splitting_op not in compile_config.splitting_ops
    if added_splitting_op:
        compile_config.splitting_ops.append(splitting_op)

    try:
        # patch.object restores (or removes) the flag when the block exits;
        # create=True mirrors plain assignment if the attribute does not exist yet.
        with patch.object(envs, "MAGI_ENABLE_PROFILE", True, create=True):
            # NOTE: envs.MAGI_ENABLE_FX_GRAPH_VIZ can be patched the same way
            # while debugging (presumably dumps FX graph visualizations).
            performer_config = ModelConfig(
                hidden_size=4096,
                num_layers=1,
                num_heads_q=32,
                num_heads_kv=8,
                head_dim=128,
                intermediate_size=16384,
                activation_type="gelu",
            )
            device = "cuda" if torch.cuda.is_available() else "cpu"
            model = CompiledTransformerModel(performer_config).to(device).to(performer_config.params_dtype)
            uncompiled_model = model.mod

            # Mix of large, small, and non-power-of-two lengths.
            test_seq_lens = [4096, 1014, 512, 101, 64, 7, 1]
            for seq_len in test_seq_lens:
                x = torch.randn(seq_len, performer_config.hidden_size, device=device, dtype=performer_config.params_dtype)
                with torch.no_grad():
                    output = model(x)
                    uncompiled_output = uncompiled_model(x)
                # The message is only evaluated when the assertion fails.
                assert torch.allclose(output, uncompiled_output, atol=1e-3), (
                    f"seq_len={seq_len}: compiled output differs from uncompiled output "
                    f"(max abs diff {(output - uncompiled_output).abs().max().item()})"
                )
    finally:
        if added_splitting_op:
            compile_config.splitting_ops.remove(splitting_op)
if __name__ == "__main__":
    # Propagate pytest's exit code so running this file as a script fails
    # (non-zero exit status) when any test fails.
    raise SystemExit(pytest.main(["-v", __file__]))