Buckets:
| # Copyright (c) 2026 SandAI. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """End-to-end tests for dynamic int graph inputs in piecewise compilation. | |
| Reproduces the scenario where ``dynamic=True`` + ``assume_static_by_default=False`` | |
| causes data-dependent Python ints (e.g. split sizes derived from | |
| ``torch.bincount`` / ``tolist``) to become symbolic graph inputs. | |
| Verifies that piecewise compilation handles non-tensor graph inputs correctly. | |
| """ | |
| import os | |
| import pytest | |
| import torch | |
| import torch.nn as nn | |
| from magi_compiler import magi_compile | |
| from magi_compiler.config import CompileMode, get_compile_config | |
| DEVICE = "cuda" | |
| HIDDEN = 64 | |
| SEQ_LEN = 128 | |
| TOLERANCE = 1e-3 | |
| def _magi_compile_mode(): | |
| config = get_compile_config() | |
| old_mode = config.compile_mode | |
| config.compile_mode = CompileMode.MAGI_COMPILE | |
| config.cache_root_dir = os.environ.get("MAGI_COMPILE_CACHE_ROOT_DIR", config.cache_root_dir) | |
| torch._dynamo.reset() | |
| yield | |
| config.compile_mode = old_mode | |
| torch._dynamo.reset() | |
| # ========================= Helpers ========================= | |
| class SimpleDispatcher: | |
| """Minimal reproduction of ModalityDispatcher: | |
| compute split sizes from a tensor, then use them in ``torch.split``. | |
| """ | |
| def __init__(self, modality_mapping: torch.Tensor, num_groups: int): | |
| self.group_size = torch.bincount(modality_mapping, minlength=num_groups).to(torch.int32) | |
| self.group_size_cpu: list[int] = [int(x) for x in self.group_size.to("cpu").tolist()] | |
| self.permute_mapping = torch.argsort(modality_mapping) | |
| self.inv_permute_mapping = torch.argsort(self.permute_mapping) | |
| def dispatch(self, x: torch.Tensor) -> list[torch.Tensor]: | |
| return list(torch.split(x, self.group_size_cpu, dim=0)) | |
| def undispatch(self, *groups: torch.Tensor) -> torch.Tensor: | |
| return torch.cat(groups, dim=0) | |
| def _make_inputs(seq_len: int, hidden: int, num_groups: int): | |
| x = torch.randn(seq_len, hidden, device=DEVICE) | |
| modality = torch.randint(0, num_groups, (seq_len,), device=DEVICE) | |
| return x, modality | |
| # ========================= Tests ========================= | |
| class TestDynamicIntInputs: | |
| """Regression tests for int graph inputs in piecewise-compiled subgraphs.""" | |
| def test_jit_compile_with_split_sizes(self): | |
| """JIT compile succeeds when ``torch.split`` receives data-dependent | |
| int sizes that become symbolic graph inputs. | |
| """ | |
| class _Block(nn.Module): | |
| def __init__(self): | |
| super().__init__() | |
| self.linears = nn.ModuleList([nn.Linear(HIDDEN, HIDDEN) for _ in range(3)]) | |
| def forward(self, x: torch.Tensor, permute_mapping: torch.Tensor, dispatcher: SimpleDispatcher): | |
| x = x[permute_mapping] | |
| groups = dispatcher.dispatch(x) | |
| processed = [self.linears[i](g) for i, g in enumerate(groups)] | |
| return dispatcher.undispatch(*processed)[dispatcher.inv_permute_mapping] | |
| class _Outer(nn.Module): | |
| def __init__(self): | |
| super().__init__() | |
| self.block = _Block() | |
| def forward(self, x, modality): | |
| d = SimpleDispatcher(modality, 3) | |
| return self.block(x, d.permute_mapping, d) | |
| model = _Outer().to(DEVICE) | |
| x, modality = _make_inputs(SEQ_LEN, HIDDEN, 3) | |
| with torch.no_grad(): | |
| out = model(x, modality) | |
| assert out.shape == x.shape | |
| def test_jit_compile_numerical_correctness(self): | |
| """Compiled output matches eager forward output.""" | |
| class _Block(nn.Module): | |
| def __init__(self): | |
| super().__init__() | |
| self.linears = nn.ModuleList([nn.Linear(HIDDEN, HIDDEN) for _ in range(3)]) | |
| def forward(self, x: torch.Tensor, permute_mapping: torch.Tensor, dispatcher: SimpleDispatcher): | |
| x = x[permute_mapping] | |
| groups = dispatcher.dispatch(x) | |
| processed = [self.linears[i](g) for i, g in enumerate(groups)] | |
| return dispatcher.undispatch(*processed)[dispatcher.inv_permute_mapping] | |
| class _Outer(nn.Module): | |
| def __init__(self): | |
| super().__init__() | |
| self.block = _Block() | |
| def forward(self, x, modality): | |
| d = SimpleDispatcher(modality, 3) | |
| return self.block(x, d.permute_mapping, d) | |
| model = _Outer().to(DEVICE) | |
| x, modality = _make_inputs(SEQ_LEN, HIDDEN, 3) | |
| # Eager: call block.forward() directly, bypassing @magi_compile wrapper | |
| dispatcher = SimpleDispatcher(modality, 3) | |
| with torch.no_grad(): | |
| eager_out = model.block.forward(x, dispatcher.permute_mapping, dispatcher) | |
| # Compiled: goes through @magi_compile | |
| with torch.no_grad(): | |
| compiled_out = model(x, modality) | |
| assert model.block._magi is not None | |
| assert model.block._magi.compiled_code is not None | |
| assert torch.allclose( | |
| eager_out, compiled_out, atol=TOLERANCE, rtol=TOLERANCE | |
| ), f"max diff = {(eager_out - compiled_out).abs().max().item()}" | |
| def test_jit_compile_different_distributions(self): | |
| """Same sequence length but different modality distributions: | |
| the symbolic int split sizes vary across calls. | |
| """ | |
| class _Block(nn.Module): | |
| def __init__(self): | |
| super().__init__() | |
| self.linears = nn.ModuleList([nn.Linear(HIDDEN, HIDDEN) for _ in range(3)]) | |
| def forward(self, x: torch.Tensor, permute_mapping: torch.Tensor, dispatcher: SimpleDispatcher): | |
| x = x[permute_mapping] | |
| groups = dispatcher.dispatch(x) | |
| processed = [self.linears[i](g) for i, g in enumerate(groups)] | |
| return dispatcher.undispatch(*processed)[dispatcher.inv_permute_mapping] | |
| class _Outer(nn.Module): | |
| def __init__(self): | |
| super().__init__() | |
| self.block = _Block() | |
| def forward(self, x, modality): | |
| d = SimpleDispatcher(modality, 3) | |
| return self.block(x, d.permute_mapping, d) | |
| model = _Outer().to(DEVICE) | |
| x = torch.randn(SEQ_LEN, HIDDEN, device=DEVICE) | |
| for seed in range(3): | |
| torch.manual_seed(seed) | |
| modality = torch.randint(0, 3, (SEQ_LEN,), device=DEVICE) | |
| with torch.no_grad(): | |
| out = model(x, modality) | |
| assert out.shape == x.shape, f"shape mismatch at seed={seed}" | |
| if __name__ == "__main__": | |
| pytest.main([__file__, "-v", "-s"]) | |
Xet Storage Details
- Size:
- 7.54 kB
- Xet hash:
- d302d0d79d87c8ff575d492b603e2580974b79d95ddc8d17549ea6915ec4e8e6
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.