rahul7star/MagiGPU-storage / MagiCompiler /tests /api_tests /test_dynamic_int_inputs.py
download
raw
7.54 kB
# Copyright (c) 2026 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""End-to-end tests for dynamic int graph inputs in piecewise compilation.
Reproduces the scenario where ``dynamic=True`` + ``assume_static_by_default=False``
causes data-dependent Python ints (e.g. split sizes derived from
``torch.bincount`` / ``tolist``) to become symbolic graph inputs.
Verifies that piecewise compilation handles non-tensor graph inputs correctly.
"""
import os
import pytest
import torch
import torch.nn as nn
from magi_compiler import magi_compile
from magi_compiler.config import CompileMode, get_compile_config
DEVICE = "cuda"
HIDDEN = 64
SEQ_LEN = 128
TOLERANCE = 1e-3
@pytest.fixture(autouse=True)
def _magi_compile_mode():
config = get_compile_config()
old_mode = config.compile_mode
config.compile_mode = CompileMode.MAGI_COMPILE
config.cache_root_dir = os.environ.get("MAGI_COMPILE_CACHE_ROOT_DIR", config.cache_root_dir)
torch._dynamo.reset()
yield
config.compile_mode = old_mode
torch._dynamo.reset()
# ========================= Helpers =========================
class SimpleDispatcher:
"""Minimal reproduction of ModalityDispatcher:
compute split sizes from a tensor, then use them in ``torch.split``.
"""
def __init__(self, modality_mapping: torch.Tensor, num_groups: int):
self.group_size = torch.bincount(modality_mapping, minlength=num_groups).to(torch.int32)
self.group_size_cpu: list[int] = [int(x) for x in self.group_size.to("cpu").tolist()]
self.permute_mapping = torch.argsort(modality_mapping)
self.inv_permute_mapping = torch.argsort(self.permute_mapping)
def dispatch(self, x: torch.Tensor) -> list[torch.Tensor]:
return list(torch.split(x, self.group_size_cpu, dim=0))
def undispatch(self, *groups: torch.Tensor) -> torch.Tensor:
return torch.cat(groups, dim=0)
def _make_inputs(seq_len: int, hidden: int, num_groups: int):
x = torch.randn(seq_len, hidden, device=DEVICE)
modality = torch.randint(0, num_groups, (seq_len,), device=DEVICE)
return x, modality
# ========================= Tests =========================
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
class TestDynamicIntInputs:
"""Regression tests for int graph inputs in piecewise-compiled subgraphs."""
def test_jit_compile_with_split_sizes(self):
"""JIT compile succeeds when ``torch.split`` receives data-dependent
int sizes that become symbolic graph inputs.
"""
@magi_compile(dynamic_arg_dims={"x": 0, "permute_mapping": 0})
class _Block(nn.Module):
def __init__(self):
super().__init__()
self.linears = nn.ModuleList([nn.Linear(HIDDEN, HIDDEN) for _ in range(3)])
def forward(self, x: torch.Tensor, permute_mapping: torch.Tensor, dispatcher: SimpleDispatcher):
x = x[permute_mapping]
groups = dispatcher.dispatch(x)
processed = [self.linears[i](g) for i, g in enumerate(groups)]
return dispatcher.undispatch(*processed)[dispatcher.inv_permute_mapping]
class _Outer(nn.Module):
def __init__(self):
super().__init__()
self.block = _Block()
def forward(self, x, modality):
d = SimpleDispatcher(modality, 3)
return self.block(x, d.permute_mapping, d)
model = _Outer().to(DEVICE)
x, modality = _make_inputs(SEQ_LEN, HIDDEN, 3)
with torch.no_grad():
out = model(x, modality)
assert out.shape == x.shape
def test_jit_compile_numerical_correctness(self):
"""Compiled output matches eager forward output."""
@magi_compile(dynamic_arg_dims={"x": 0, "permute_mapping": 0})
class _Block(nn.Module):
def __init__(self):
super().__init__()
self.linears = nn.ModuleList([nn.Linear(HIDDEN, HIDDEN) for _ in range(3)])
def forward(self, x: torch.Tensor, permute_mapping: torch.Tensor, dispatcher: SimpleDispatcher):
x = x[permute_mapping]
groups = dispatcher.dispatch(x)
processed = [self.linears[i](g) for i, g in enumerate(groups)]
return dispatcher.undispatch(*processed)[dispatcher.inv_permute_mapping]
class _Outer(nn.Module):
def __init__(self):
super().__init__()
self.block = _Block()
def forward(self, x, modality):
d = SimpleDispatcher(modality, 3)
return self.block(x, d.permute_mapping, d)
model = _Outer().to(DEVICE)
x, modality = _make_inputs(SEQ_LEN, HIDDEN, 3)
# Eager: call block.forward() directly, bypassing @magi_compile wrapper
dispatcher = SimpleDispatcher(modality, 3)
with torch.no_grad():
eager_out = model.block.forward(x, dispatcher.permute_mapping, dispatcher)
# Compiled: goes through @magi_compile
with torch.no_grad():
compiled_out = model(x, modality)
assert model.block._magi is not None
assert model.block._magi.compiled_code is not None
assert torch.allclose(
eager_out, compiled_out, atol=TOLERANCE, rtol=TOLERANCE
), f"max diff = {(eager_out - compiled_out).abs().max().item()}"
def test_jit_compile_different_distributions(self):
"""Same sequence length but different modality distributions:
the symbolic int split sizes vary across calls.
"""
@magi_compile(dynamic_arg_dims={"x": 0, "permute_mapping": 0})
class _Block(nn.Module):
def __init__(self):
super().__init__()
self.linears = nn.ModuleList([nn.Linear(HIDDEN, HIDDEN) for _ in range(3)])
def forward(self, x: torch.Tensor, permute_mapping: torch.Tensor, dispatcher: SimpleDispatcher):
x = x[permute_mapping]
groups = dispatcher.dispatch(x)
processed = [self.linears[i](g) for i, g in enumerate(groups)]
return dispatcher.undispatch(*processed)[dispatcher.inv_permute_mapping]
class _Outer(nn.Module):
def __init__(self):
super().__init__()
self.block = _Block()
def forward(self, x, modality):
d = SimpleDispatcher(modality, 3)
return self.block(x, d.permute_mapping, d)
model = _Outer().to(DEVICE)
x = torch.randn(SEQ_LEN, HIDDEN, device=DEVICE)
for seed in range(3):
torch.manual_seed(seed)
modality = torch.randint(0, 3, (SEQ_LEN,), device=DEVICE)
with torch.no_grad():
out = model(x, modality)
assert out.shape == x.shape, f"shape mismatch at seed={seed}"
if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])

Xet Storage Details

Size:
7.54 kB
·
Xet hash:
d302d0d79d87c8ff575d492b603e2580974b79d95ddc8d17549ea6915ec4e8e6

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.