rahul7star/MagiGPU-storage / MagiCompiler /tests /feature_tests /test_transformer_cache_reuse.py
download
raw
2.8 kB
# Copyright (c) 2026 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
from __future__ import annotations
import json
import os
import subprocess
import sys
from pathlib import Path
import pytest
import torch
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
@pytest.mark.parametrize("run_mode", ["jit", "aot"])
def test_transformer_cache_reuse_across_processes(tmp_path: Path, run_mode: str):
helper_path = Path(__file__).parent / "cache_reuse_helper" / "transformer_cache_reuse_helper.py"
cache_root = tmp_path / "cache_shared"
baseline_out = tmp_path / f"baseline_{run_mode}.json"
cache_out1 = tmp_path / f"cache1_{run_mode}.json"
cache_out2 = tmp_path / f"cache2_{run_mode}.json"
env = os.environ.copy()
env["MAGI_LOGGING_LEVEL"] = "info"
baseline = subprocess.run(
[
sys.executable,
str(helper_path),
"--cache-root",
str(tmp_path / f"baseline_cache_{run_mode}"),
"--output",
str(baseline_out),
"--run-mode",
run_mode,
"--run-kind",
"baseline",
],
capture_output=True,
text=True,
env=env,
)
assert baseline.returncode == 0, f"baseline process failed (mode={run_mode})\n{baseline.stderr}"
run1 = subprocess.run(
[
sys.executable,
str(helper_path),
"--cache-root",
str(cache_root),
"--output",
str(cache_out1),
"--run-mode",
run_mode,
"--run-kind",
"cache",
],
capture_output=True,
text=True,
env=env,
)
assert run1.returncode == 0, f"cache run1 failed (mode={run_mode})\n{run1.stderr}"
run2 = subprocess.run(
[
sys.executable,
str(helper_path),
"--cache-root",
str(cache_root),
"--output",
str(cache_out2),
"--run-mode",
run_mode,
"--run-kind",
"cache",
],
capture_output=True,
text=True,
env=env,
)
assert run2.returncode == 0, f"cache run2 failed (mode={run_mode})\n{run2.stderr}"
baseline_payload = json.loads(baseline_out.read_text())
payload1 = json.loads(cache_out1.read_text())
payload2 = json.loads(cache_out2.read_text())
expected_shape = baseline_payload["shape"]
assert payload1["shape"] == expected_shape
assert payload2["shape"] == expected_shape
assert abs(payload1["sum"] - payload2["sum"]) < 1e-2
assert abs(payload1["mean"] - payload2["mean"]) < 1e-4

Xet Storage Details

Size:
2.8 kB
·
Xet hash:
0cd2f3b6819cb69a63b2e007e3e392ca47541ab580894572310f3df56f796baa

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.