| import io |
| import multiprocessing |
| import os |
| import sys |
| import threading |
| import time |
| from contextlib import contextmanager |
| from pathlib import Path |
|
|
| import pytest |
| import requests |
| import torch |
| import torch.distributed as dist |
|
|
| from sglang.srt.debug_utils.dumper import ( |
| DumperConfig, |
| _collective_with_timeout, |
| _deepcopy_or_clone, |
| _detect_recompute_status, |
| _Dumper, |
| _format_tags, |
| _get_default_exp_name, |
| _map_tensor, |
| _materialize_value, |
| _MegatronPlugin, |
| _obj_to_dict, |
| _RecomputeStatus, |
| _register_forward_hook_or_replace_fn, |
| _SGLangPlugin, |
| _torch_save, |
| dumper, |
| get_tensor_info, |
| get_truncated_value, |
| ) |
| from sglang.srt.environ import temp_set_env |
| from sglang.srt.utils import kill_process_tree |
| from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci |
| from sglang.test.test_utils import ( |
| DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, |
| DEFAULT_URL_FOR_TEST, |
| find_available_port, |
| popen_launch_server, |
| run_distributed_test, |
| ) |
|
|
| register_cuda_ci(est_time=30, suite="nightly-2-gpu", nightly=True) |
| register_amd_ci(est_time=60, suite="nightly-amd", nightly=True) |
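| # Register this module with the CI harness; est_time is the estimated |
| # runtime (presumably in seconds) used for nightly scheduling. |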
|
|
|
|
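| # Scoped stdout capture. Unlike pytest's capsys fixture, this also works |
| # inside spawned distributed workers, where fixtures are unavailable. |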
| @contextmanager |
| def _capture_stdout(): |
| captured = io.StringIO() |
| old_stdout = sys.stdout |
| sys.stdout = captured |
| try: |
| yield captured |
| finally: |
| sys.stdout = old_stdout |
|
|
|
|
| class TestDumperConfig: |
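| # DumperConfig.from_env reads DUMPER_<FIELD> environment variables |
| # (e.g. DUMPER_ENABLE -> enable); blank values count as unset. |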
| def test_from_env_defaults_match_dataclass_defaults(self): |
| assert DumperConfig.from_env() == DumperConfig() |
|
|
| def test_from_env_bool(self): |
| with temp_set_env(DUMPER_ENABLE="1"): |
| assert DumperConfig.from_env().enable is True |
| with temp_set_env(DUMPER_ENABLE="false"): |
| assert DumperConfig.from_env().enable is False |
|
|
| def test_from_env_str(self): |
| with temp_set_env(DUMPER_FILTER="layer_id=0"): |
| assert DumperConfig.from_env().filter == "layer_id=0" |
|
|
| def test_from_env_dir(self): |
| with temp_set_env(DUMPER_DIR="/my/dir"): |
| assert DumperConfig.from_env().dir == "/my/dir" |
|
|
| def test_from_env_int(self): |
| with temp_set_env(DUMPER_COLLECTIVE_TIMEOUT="120"): |
| assert DumperConfig.from_env().collective_timeout == 120 |
|
|
| def test_configure_overrides(self): |
| d = _make_test_dumper("/tmp") |
| d.configure(enable=False) |
| assert d._config.enable is False |
| d.configure(enable=True) |
| assert d._config.enable is True |
|
|
| def test_type_validation(self): |
| with pytest.raises(TypeError, match="enable.*expected bool.*got str"): |
| DumperConfig(enable="yes") |
| with pytest.raises( |
| TypeError, match="collective_timeout.*expected int.*got str" |
| ): |
| DumperConfig(collective_timeout="abc") |
| with pytest.raises(TypeError, match="filter.*expected str.*got int"): |
| DumperConfig(filter=123) |
|
|
| def test_configure_default_skips_when_env_set(self): |
| with temp_set_env(DUMPER_FILTER="from_env"): |
| d = _Dumper(config=DumperConfig.from_env()) |
| d.configure_default(filter="from_code") |
| assert d._config.filter == "from_env" |
|
|
| def test_configure_default_applies_when_no_env(self): |
| d = _Dumper(config=DumperConfig.from_env()) |
| d.configure_default(filter="from_code") |
| assert d._config.filter == "from_code" |
|
|
| def test_from_env_whitespace_treated_as_unset(self): |
| with temp_set_env(DUMPER_FILTER=" "): |
| assert DumperConfig.from_env().filter is None |
|
|
| def test_may_enable_default_false(self): |
| d = _Dumper(config=DumperConfig()) |
| assert d.may_enable is False |
|
|
| def test_may_enable_true_when_enabled(self): |
| d = _Dumper(config=DumperConfig(enable=True)) |
| assert d.may_enable is True |
|
|
| def test_may_enable_true_when_server_port_set(self): |
| d = _Dumper(config=DumperConfig(server_port="40000")) |
| assert d.may_enable is True |
|
|
| d2 = _Dumper(config=DumperConfig(server_port="reuse")) |
| assert d2.may_enable is True |
|
|
|
|
| class TestServerPortParsed: |
| def test_negative_returns_none(self): |
| assert DumperConfig(server_port="-1").server_port_parsed is None |
|
|
| def test_zero_returns_none(self): |
| assert DumperConfig(server_port="0").server_port_parsed is None |
|
|
| def test_positive_returns_int(self): |
| result = DumperConfig(server_port="40000").server_port_parsed |
| assert result == 40000 |
| assert isinstance(result, int) |
|
|
| def test_reuse_returns_string(self): |
| assert DumperConfig(server_port="reuse").server_port_parsed == "reuse" |
|
|
|
|
| class TestDefaultExpName: |
| def test_starts_with_prefix(self): |
| name = _get_default_exp_name(timeout_seconds=5) |
| assert name.startswith("dump_") |
|
|
| def test_suffix_format(self): |
| name = _get_default_exp_name(timeout_seconds=5) |
| suffix = name[len("dump_") :] |
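| # Assumed format: "YYYYMMDD_HHMMSS_ffffff" (8 + 1 + 6 + 1 + 6 = 22 chars, |
| # "_" at index 8), per the assertions below. |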
| assert len(suffix) == 22 |
| assert suffix[8] == "_" |
|
|
|
|
| class TestKvPairsParsing: |
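| # Values are coerced to each field's declared type: bools accept |
| # true/false/1/0, ints are parsed, and str fields are never coerced |
| # (per the assertions below). |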
| def test_from_kv_pairs_none_returns_defaults(self): |
| assert DumperConfig.from_kv_pairs(None) == DumperConfig() |
|
|
| def test_from_kv_pairs_empty_returns_defaults(self): |
| assert DumperConfig.from_kv_pairs([]) == DumperConfig() |
|
|
| def test_from_kv_pairs_bool_field(self): |
| cfg = DumperConfig.from_kv_pairs(["enable=true"]) |
| assert cfg.enable is True |
| assert cfg.dir == "/tmp/dumper" |
|
|
| def test_from_kv_pairs_bool_numeric(self): |
| assert DumperConfig.from_kv_pairs(["enable=1"]).enable is True |
| assert DumperConfig.from_kv_pairs(["enable=0"]).enable is False |
|
|
| def test_from_kv_pairs_int_field(self): |
| cfg = DumperConfig.from_kv_pairs(["collective_timeout=120"]) |
| assert cfg.collective_timeout == 120 |
| assert type(cfg.collective_timeout) is int |
|
|
| def test_from_kv_pairs_int_field_zero_stays_int(self): |
| cfg = DumperConfig.from_kv_pairs(["collective_timeout=0"]) |
| assert cfg.collective_timeout == 0 |
| assert type(cfg.collective_timeout) is int |
|
|
| def test_from_kv_pairs_str_field_not_coerced(self): |
| cfg = DumperConfig.from_kv_pairs(["server_port=0"]) |
| assert cfg.server_port == "0" |
| assert type(cfg.server_port) is str |
|
|
| def test_from_kv_pairs_str_field_one_stays_str(self): |
| cfg = DumperConfig.from_kv_pairs(["server_port=1"]) |
| assert cfg.server_port == "1" |
| assert type(cfg.server_port) is str |
|
|
| def test_from_kv_pairs_optional_str_field(self): |
| cfg = DumperConfig.from_kv_pairs( |
| ["filter=layer_id is not None and layer_id < 3"] |
| ) |
| assert cfg.filter == "layer_id is not None and layer_id < 3" |
|
|
| def test_from_kv_pairs_optional_str_exp_name(self): |
| cfg = DumperConfig.from_kv_pairs(["exp_name=my_experiment"]) |
| assert cfg.exp_name == "my_experiment" |
|
|
| def test_from_kv_pairs_multiple_fields(self): |
| cfg = DumperConfig.from_kv_pairs( |
| [ |
| "enable=true", |
| "dir=/my/dir", |
| "filter=name == 'foo'", |
| "collective_timeout=30", |
| "enable_grad=1", |
| ] |
| ) |
| assert cfg.enable is True |
| assert cfg.dir == "/my/dir" |
| assert cfg.filter == "name == 'foo'" |
| assert cfg.collective_timeout == 30 |
| assert cfg.enable_grad is True |
|
|
| def test_from_kv_pairs_missing_equals_raises(self): |
| with pytest.raises(ValueError, match="missing '='"): |
| DumperConfig.from_kv_pairs(["enable"]) |
|
|
| def test_from_kv_pairs_unknown_key_raises(self): |
| with pytest.raises(ValueError, match="Unknown config key"): |
| DumperConfig.from_kv_pairs(["nonexistent=true"]) |
|
|
| def test_kv_pairs_to_dict_returns_only_explicit(self): |
| d = DumperConfig._kv_pairs_to_dict(["enable=true", "dir=/x"]) |
| assert d == {"enable": True, "dir": "/x"} |
| assert "filter" not in d |
| assert "collective_timeout" not in d |
|
|
| def test_kv_pairs_to_dict_none_returns_empty(self): |
| assert DumperConfig._kv_pairs_to_dict(None) == {} |
|
|
| def test_kv_pairs_to_dict_empty_returns_empty(self): |
| assert DumperConfig._kv_pairs_to_dict([]) == {} |
|
|
| def test_from_kv_pairs_value_with_equals_in_value(self): |
| cfg = DumperConfig.from_kv_pairs(["filter=name == 'foo'"]) |
| assert cfg.filter == "name == 'foo'" |
|
|
| def test_from_kv_pairs_type_validation_still_works(self): |
| with pytest.raises(TypeError, match="collective_timeout.*expected int"): |
| DumperConfig.from_kv_pairs(["collective_timeout=not_a_number"]) |
|
|
|
|
| class TestDumperPureFunctions: |
| def test_get_truncated_value(self): |
| assert get_truncated_value(None) is None |
| assert get_truncated_value(42) == 42 |
| assert len(get_truncated_value((torch.randn(10), torch.randn(20)))) == 2 |
| assert get_truncated_value(torch.randn(10, 10)).shape == (10, 10) |
| assert get_truncated_value(torch.randn(100, 100)).shape == (5, 5) |
|
|
| def test_obj_to_dict(self): |
| assert _obj_to_dict({"a": 1}) == {"a": 1} |
|
|
| class Obj: |
| x, y = 10, 20 |
|
|
| def method(self): |
| pass |
|
|
| result = _obj_to_dict(Obj()) |
| assert result["x"] == 10 |
| assert "method" not in result |
|
|
| def test_deepcopy_or_clone_tensor(self): |
| original = torch.randn(3, 3) |
| cloned = _deepcopy_or_clone(original) |
| assert torch.equal(cloned, original) |
| original.fill_(999.0) |
| assert not torch.equal(cloned, original) |
|
|
| def test_deepcopy_or_clone_non_tensor(self): |
| original = {"a": [1, 2, 3]} |
| cloned = _deepcopy_or_clone(original) |
| assert cloned == original |
| assert cloned is not original |
| original["a"].append(4) |
| assert len(cloned["a"]) == 3 |
|
|
| def test_get_tensor_info(self): |
| info = get_tensor_info(torch.randn(10, 10)) |
| for key in ["shape=", "dtype=", "min=", "max=", "mean="]: |
| assert key in info |
|
|
| assert "value=42" in get_tensor_info(42) |
| assert "min=None" in get_tensor_info(torch.tensor([])) |
|
|
|
|
| class TestMapTensor: |
| def test_bare_tensor(self): |
| t = torch.randn(4) |
| result = _map_tensor(t, lambda x: x * 2) |
| assert torch.equal(result, t * 2) |
|
|
| def test_bare_tensor_no_change(self): |
| t = torch.randn(4) |
| result = _map_tensor(t, lambda x: x) |
| assert result is t |
|
|
| def test_dict_with_tensor_values(self): |
| t1 = torch.randn(3) |
| t2 = torch.randn(5) |
| value = {"a": t1, "b": t2, "meta": "not a tensor"} |
| result = _map_tensor(value, lambda x: x.clone()) |
| assert torch.equal(result["a"], t1) |
| assert torch.equal(result["b"], t2) |
| assert result["a"] is not t1 |
| assert result["b"] is not t2 |
| assert result["meta"] == "not a tensor" |
|
|
| def test_dict_no_tensors(self): |
| value = {"a": 1, "b": "hello"} |
| result = _map_tensor(value, lambda x: x.clone()) |
| assert result == value |
|
|
| def test_nested_dict(self): |
| inner_t = torch.randn(3) |
| value = {"outer": {"inner": inner_t, "label": "ok"}, "top": torch.randn(2)} |
| result = _map_tensor(value, lambda x: x.clone()) |
| assert torch.equal(result["outer"]["inner"], inner_t) |
| assert result["outer"]["inner"] is not inner_t |
| assert result["outer"]["label"] == "ok" |
| assert result is not value |
| assert result["outer"] is not value["outer"] |
|
|
| def test_non_tensor_non_dict(self): |
| result = _map_tensor(42, lambda x: x.clone()) |
| assert result == 42 |
|
|
|
|
| class TestTorchSave: |
| def test_normal(self, tmp_path): |
| path = str(tmp_path / "a.pt") |
| tensor = torch.randn(3, 3) |
|
|
| _torch_save(tensor, path) |
|
|
| assert torch.equal(torch.load(path, weights_only=True), tensor) |
|
|
| def test_parameter_fallback(self, tmp_path): |
| class BadParam(torch.nn.Parameter): |
| def __reduce_ex__(self, protocol): |
| raise RuntimeError("not pickleable") |
|
|
| path = str(tmp_path / "b.pt") |
| param = BadParam(torch.randn(4)) |
|
|
| _torch_save(param, path) |
|
|
| assert torch.equal(torch.load(path, weights_only=True), param.data) |
|
|
| def test_shared_storage_not_bloated(self, tmp_path): |
| big = torch.randn(1000, 1000) |
| view = big[0] |
| path = str(tmp_path / "view.pt") |
|
|
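| # A plain torch.save would serialize the view's full 4 MB backing storage; |
| # _torch_save presumably clones the view to a compact tensor first |
| # (inferred from the size bound asserted below). |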
| _torch_save({"value": view, "meta": {}}, path) |
|
|
| file_size = Path(path).stat().st_size |
| expected_max = view.nelement() * view.element_size() * 10 |
| assert file_size < expected_max, ( |
| f"File {file_size} bytes but view is only " |
| f"{view.nelement() * view.element_size()} bytes — " |
| f"torch.save likely serialized the full " |
| f"{big.nelement() * big.element_size()} byte storage" |
| ) |
|
|
| def test_silent_skip(self, tmp_path, capsys): |
| path = str(tmp_path / "c.pt") |
|
|
| _torch_save({"fn": lambda: None}, path) |
|
|
| captured = capsys.readouterr() |
| assert "[Dumper] Observe error=" in captured.out |
| assert "skip the tensor" in captured.out |
|
|
|
|
| class TestCollectiveTimeout: |
| def test_watchdog_fires_on_timeout(self): |
| block_event = threading.Event() |
| output = "" |
|
|
| def run_with_timeout(): |
| nonlocal output |
| with _capture_stdout() as captured: |
| _collective_with_timeout( |
| lambda: block_event.wait(), |
| operation_name="test_blocked_op", |
| timeout_seconds=2, |
| ) |
| output = captured.getvalue() |
|
|
| worker = threading.Thread(target=run_with_timeout) |
| worker.start() |
|
|
| time.sleep(4) |
| block_event.set() |
| worker.join(timeout=5) |
|
|
| print(f"Captured output: {output!r}") |
| assert "WARNING" in output |
| assert "test_blocked_op" in output |
| assert "2s" in output |
|
|
|
|
| class TestDumperDistributed: |
| def test_basic(self, tmp_path): |
| with temp_set_env( |
| DUMPER_ENABLE="1", |
| DUMPER_DIR=str(tmp_path), |
| ): |
| run_distributed_test(self._test_basic_func, tmpdir=str(tmp_path)) |
|
|
| @staticmethod |
| def _test_basic_func(rank, tmpdir): |
| tensor = torch.randn(10, 10, device=f"cuda:{rank}") |
|
|
| dumper.dump("tensor_a", tensor, arg=100) |
| dumper.step() |
|
|
| dumper.set_ctx(ctx_arg=200) |
| dumper.dump("tensor_b", tensor) |
| dumper.set_ctx(ctx_arg=None) |
| dumper.step() |
|
|
| dumper.configure(filter="False") |
| dumper.dump("tensor_skip", tensor) |
| dumper.configure(filter=None) |
| dumper.step() |
|
|
| dumper.dump_dict("obj", {"a": torch.randn(3, device=f"cuda:{rank}"), "b": 42}) |
| dumper.step() |
|
|
| dist.barrier() |
| filenames = _get_filenames(tmpdir) |
| _assert_files( |
| filenames, |
| exist=["tensor_a", "tensor_b", "arg=100", "ctx_arg=200", "obj_a", "obj_b"], |
| not_exist=["tensor_skip"], |
| ) |
|
|
| def test_collective_timeout(self): |
| with temp_set_env(DUMPER_ENABLE="1"): |
| run_distributed_test(self._test_collective_timeout_func) |
|
|
| @staticmethod |
| def _test_collective_timeout_func(rank): |
| dumper = _Dumper( |
| config=DumperConfig( |
| enable=True, |
| collective_timeout=3, |
| ), |
| ) |
|
|
| with _capture_stdout() as captured: |
| if rank != 0: |
| time.sleep(6) |
| dumper.step() |
|
|
| output = captured.getvalue() |
| print(f"Rank {rank} captured output: {output!r}") |
|
|
| if rank == 0: |
| assert "WARNING" in output, f"Expected WARNING in rank 0 output: {output}" |
| assert "has not completed after 3s" in output |
|
|
| def test_file_content_correctness(self, tmp_path): |
| with temp_set_env( |
| DUMPER_ENABLE="1", |
| DUMPER_DIR=str(tmp_path), |
| ): |
| run_distributed_test(self._test_file_content_func, tmpdir=str(tmp_path)) |
|
|
| @staticmethod |
| def _test_file_content_func(rank, tmpdir): |
| tensor = torch.arange(12, device=f"cuda:{rank}").reshape(3, 4).float() |
|
|
| dumper.dump("content_check", tensor) |
| dumper.step() |
|
|
| dist.barrier() |
| path = _find_dump_file(tmpdir, rank=rank, name="content_check") |
| raw = _load_dump(path) |
| assert isinstance(raw, dict), f"Expected dict, got {type(raw)}" |
| assert "value" in raw and "meta" in raw |
| assert torch.equal(raw["value"], tensor.cpu()) |
| assert raw["meta"]["name"] == "content_check" |
| assert raw["meta"]["rank"] == rank |
|
|
|
|
| class TestDumperFileWriteControl: |
| def test_filter(self, tmp_path): |
| with temp_set_env( |
| DUMPER_ENABLE="1", |
| DUMPER_DIR=str(tmp_path), |
| DUMPER_FILTER="name.startswith('keep')", |
| ): |
| run_distributed_test(self._test_filter_func, tmpdir=str(tmp_path)) |
|
|
| @staticmethod |
| def _test_filter_func(rank, tmpdir): |
| dumper.dump("keep_this", torch.randn(5, device=f"cuda:{rank}")) |
| dumper.dump("skip_this", torch.randn(5, device=f"cuda:{rank}")) |
| dumper.dump("not_keep_this", torch.randn(5, device=f"cuda:{rank}")) |
| dumper.step() |
|
|
| dist.barrier() |
| filenames = _get_filenames(tmpdir) |
| _assert_files( |
| filenames, |
| exist=["keep_this"], |
| not_exist=["skip_this", "not_keep_this"], |
| ) |
|
|
| def test_save_false(self, tmp_path): |
| with temp_set_env( |
| DUMPER_ENABLE="1", |
| DUMPER_DIR=str(tmp_path), |
| ): |
| run_distributed_test(self._test_save_false_func, tmpdir=str(tmp_path)) |
|
|
| @staticmethod |
| def _test_save_false_func(rank, tmpdir): |
| dumper.dump("no_save_tensor", torch.randn(5, device=f"cuda:{rank}"), save=False) |
| dumper.step() |
|
|
| dist.barrier() |
| assert len(_get_filenames(tmpdir)) == 0 |
|
|
|
|
| class TestDumpEnableFlags: |
| def test_all_enables_false_no_output(self, tmp_path): |
| d = _make_test_dumper(tmp_path, enable_value=False, enable_grad=False) |
| d.dump("should_skip", torch.randn(3, 3)) |
| assert len(_get_filenames(tmp_path)) == 0 |
|
|
|
|
| class TestOutputControl: |
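| # Dumps can go to files, to the console, or into an in-memory dict via |
| # capture_output(), which also suppresses file writes (per the tests). |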
| def test_file_enabled_by_default(self, tmp_path): |
| d = _make_test_dumper(tmp_path) |
| d.dump("file_on", torch.randn(3, 3)) |
|
|
| _assert_files(_get_filenames(tmp_path), exist=["file_on"]) |
|
|
| def test_file_disabled(self, tmp_path, capsys): |
| d = _make_test_dumper(tmp_path, enable_output_file=False) |
| d.dump("file_off", torch.randn(3, 3)) |
|
|
| assert len(_get_filenames(tmp_path)) == 0 |
| assert "file_off" in capsys.readouterr().out |
|
|
| def test_console_enabled_by_default(self, tmp_path, capsys): |
| d = _make_test_dumper(tmp_path) |
| d.dump("console_on", torch.randn(3, 3)) |
|
|
| captured = capsys.readouterr() |
| assert "[Dumper.Value]" in captured.out |
| assert "console_on" in captured.out |
|
|
| def test_console_disabled(self, tmp_path, capsys): |
| d = _make_test_dumper(tmp_path, enable_output_console=False) |
| d.dump("console_off", torch.randn(3, 3)) |
|
|
| assert "console_off" not in capsys.readouterr().out |
| _assert_files(_get_filenames(tmp_path), exist=["console_off"]) |
|
|
| def test_capture_output_basic(self, tmp_path): |
| d = _make_test_dumper(tmp_path) |
| tensor = torch.randn(4, 4) |
|
|
| with d.capture_output() as captured: |
| d.dump("cap_basic", tensor) |
|
|
| assert "cap_basic" in captured |
| assert set(captured["cap_basic"].keys()) == {"value", "meta"} |
| assert torch.equal(captured["cap_basic"]["value"], tensor) |
| assert captured["cap_basic"]["meta"]["name"] == "cap_basic" |
|
|
| def test_capture_output_no_file(self, tmp_path): |
| d = _make_test_dumper(tmp_path) |
|
|
| with d.capture_output() as captured: |
| d.dump("cap_no_file", torch.randn(3, 3)) |
|
|
| assert "cap_no_file" in captured |
| assert len(_get_filenames(tmp_path)) == 0 |
|
|
| def test_capture_output_multiple(self, tmp_path): |
| d = _make_test_dumper(tmp_path) |
|
|
| with d.capture_output() as captured: |
| d.dump("first", torch.randn(2, 2)) |
| d.dump("second", torch.randn(3, 3)) |
|
|
| assert set(captured.keys()) == {"first", "second"} |
| assert captured["first"]["value"].shape == (2, 2) |
| assert captured["second"]["value"].shape == (3, 3) |
|
|
| def test_capture_output_value_cloned(self, tmp_path): |
| d = _make_test_dumper(tmp_path) |
| tensor = torch.zeros(3, 3) |
|
|
| with d.capture_output() as captured: |
| d.dump("clone_check", tensor) |
|
|
| tensor.fill_(999.0) |
| assert torch.equal(captured["clone_check"]["value"], torch.zeros(3, 3)) |
|
|
| def test_capture_output_nested_raises(self, tmp_path): |
| d = _make_test_dumper(tmp_path) |
| with d.capture_output(): |
| with pytest.raises(AssertionError): |
| with d.capture_output(): |
| pass |
|
|
| def test_capture_output_respects_filter(self, tmp_path): |
| d = _make_test_dumper(tmp_path, filter="'keep' in name") |
|
|
| with d.capture_output() as captured: |
| d.dump("keep_this", torch.randn(3, 3)) |
| d.dump("skip_this", torch.randn(3, 3)) |
|
|
| assert "keep_this" in captured |
| assert "skip_this" not in captured |
|
|
|
|
| class TestDumpDictFormat: |
| """Verify that dump files use the dict output format: {"value": ..., "meta": {...}}.""" |
|
|
| def test_dict_format_structure(self, tmp_path): |
| dumper = _make_test_dumper(tmp_path) |
| tensor = torch.randn(4, 4) |
| dumper.dump("fmt_test", tensor, custom_key="hello") |
|
|
| path = _find_dump_file(str(tmp_path), rank=0, name="fmt_test") |
| raw = _load_dump(path) |
|
|
| assert isinstance(raw, dict) |
| assert set(raw.keys()) == {"value", "meta"} |
| assert torch.equal(raw["value"], tensor) |
|
|
| meta = raw["meta"] |
| assert meta["name"] == "fmt_test" |
| assert meta["custom_key"] == "hello" |
| assert "step" in meta |
| assert "rank" in meta |
| assert "dump_index" in meta |
|
|
| def test_dict_format_with_context(self, tmp_path): |
| dumper = _make_test_dumper(tmp_path) |
| dumper.set_ctx(ctx_val=42) |
| tensor = torch.randn(2, 2) |
| dumper.dump("ctx_fmt", tensor) |
|
|
| path = _find_dump_file(str(tmp_path), rank=0, name="ctx_fmt") |
| raw = _load_dump(path) |
|
|
| assert raw["meta"]["ctx_val"] == 42 |
| assert torch.equal(raw["value"], tensor) |
|
|
|
|
| def _make_test_dumper(tmp_path, **overrides) -> _Dumper: |
| """Create a _Dumper for CPU testing without distributed.""" |
| defaults = dict( |
| enable=True, |
| dir=str(tmp_path), |
| exp_name="test", |
| ) |
| defaults.update(overrides) |
| config = DumperConfig(**defaults) |
| return _Dumper(config=config) |
|
|
|
|
| def _get_filenames(tmpdir): |
| return {f.name for f in Path(tmpdir).glob("*/*.pt")} |
|
|
|
|
| def _assert_files(filenames, *, exist=(), not_exist=()): |
| for p in exist: |
| assert any(p in f for f in filenames), f"{p} not found in {filenames}" |
| for p in not_exist: |
| assert not any( |
| p in f for f in filenames |
| ), f"{p} should not exist in {filenames}" |
|
|
|
|
| def _load_dump(path: Path) -> dict: |
| """Load a dump file and return the raw dict (with 'value' and 'meta' keys).""" |
| return torch.load(path, map_location="cpu", weights_only=False) |
|
|
|
|
| def _find_dump_file(tmpdir, *, rank: int = 0, name: str) -> Path: |
| matches = [ |
| f |
| for f in Path(tmpdir).glob("*/*.pt") |
| if f"rank={rank}" in f.name and name in f.name |
| ] |
| assert ( |
| len(matches) == 1 |
| ), f"Expected 1 file matching rank={rank} name={name}, got {matches}" |
| return matches[0] |
|
|
|
|
| class TestMaterializeValue: |
| def test_materialize_value_callable(self): |
| tensor = torch.randn(3, 3) |
| result = _materialize_value(lambda: tensor) |
| assert torch.equal(result, tensor) |
|
|
| def test_materialize_value_passthrough(self): |
| tensor = torch.randn(3, 3) |
| result = _materialize_value(tensor) |
| assert result is tensor |
|
|
| def test_dump_with_callable_value(self, tmp_path): |
| d = _make_test_dumper(tmp_path) |
| tensor = torch.randn(4, 4) |
| d.dump("lazy_tensor", lambda: tensor) |
|
|
| _assert_files(_get_filenames(tmp_path), exist=["name=lazy_tensor"]) |
|
|
| path = _find_dump_file(tmp_path, rank=0, name="lazy_tensor") |
| assert torch.equal(_load_dump(path)["value"], tensor) |
|
|
|
|
| class TestSaveValue: |
| def test_dump_output_format(self, tmp_path): |
| dumper = _make_test_dumper(tmp_path) |
| tensor = torch.randn(4, 4) |
|
|
| dumper.dump("dict_test", tensor) |
|
|
| path = _find_dump_file(tmp_path, rank=0, name="dict_test") |
| loaded = _load_dump(path) |
| assert torch.equal(loaded["value"], tensor) |
| assert loaded["meta"]["name"] == "dict_test" |
| assert loaded["meta"]["rank"] == 0 |
|
|
|
|
| class TestStaticMetadata: |
| def test_static_meta_contains_world_info(self): |
| dumper = _make_test_dumper("/tmp") |
| meta = dumper._static_meta |
| assert "world_rank" in meta |
| assert "world_size" in meta |
| assert meta["world_rank"] == 0 |
| assert meta["world_size"] == 1 |
|
|
| def test_static_meta_caching(self): |
| dumper = _make_test_dumper("/tmp") |
| meta1 = dumper._static_meta |
| meta2 = dumper._static_meta |
| assert meta1 is meta2 |
|
|
| def test_parallel_info_graceful_fallback(self): |
| sglang_info = _SGLangPlugin().collect_parallel_info() |
| assert isinstance(sglang_info, dict) |
|
|
| megatron_info = _MegatronPlugin().collect_parallel_info() |
| assert isinstance(megatron_info, dict) |
|
|
| def test_dump_includes_static_meta(self, tmp_path): |
| dumper = _make_test_dumper(tmp_path) |
| tensor = torch.randn(2, 2) |
|
|
| dumper.dump("meta_test", tensor) |
|
|
| path = _find_dump_file(tmp_path, rank=0, name="meta_test") |
| loaded = _load_dump(path) |
| meta = loaded["meta"] |
| assert "world_rank" in meta |
| assert "world_size" in meta |
|
|
|
|
| class TestDumpGrad: |
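| # With enable_grad, dump() presumably registers a gradient hook on the |
| # tensor: the grad file is written on backward() and records the step |
| # that was current at dump() time (see test_dump_grad_captures_step). |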
| def test_dump_grad_basic(self, tmp_path): |
| d = _make_test_dumper(tmp_path, enable_grad=True) |
| x = torch.randn(3, 3, requires_grad=True) |
| y = (x * 2).sum() |
|
|
| d.dump("test_tensor", x) |
| y.backward() |
|
|
| filenames = _get_filenames(tmp_path) |
| assert any("name=test_tensor" in f and "grad__" not in f for f in filenames) |
| _assert_files(filenames, exist=["grad__test_tensor"]) |
|
|
| def test_dump_grad_non_tensor_skipped(self, tmp_path): |
| d = _make_test_dumper(tmp_path, enable_grad=True) |
| d.dump("not_tensor", 42) |
|
|
| _assert_files(_get_filenames(tmp_path), not_exist=["grad__"]) |
|
|
| def test_dump_grad_no_requires_grad_skipped(self, tmp_path): |
| d = _make_test_dumper(tmp_path, enable_grad=True) |
| x = torch.randn(3, 3, requires_grad=False) |
| d.dump("no_grad_tensor", x) |
|
|
| _assert_files( |
| _get_filenames(tmp_path), |
| exist=["name=no_grad_tensor"], |
| not_exist=["grad__"], |
| ) |
|
|
| def test_dump_grad_captures_step(self, tmp_path): |
| d = _make_test_dumper(tmp_path, enable_grad=True) |
| d._state.step = 42 |
| x = torch.randn(3, 3, requires_grad=True) |
| y = (x * 2).sum() |
|
|
| d.dump("id_test", x) |
| d._state.step = 999 |
| y.backward() |
|
|
| grad_file = _find_dump_file(tmp_path, name="grad__id_test") |
| assert "step=42" in grad_file.name |
|
|
| def test_dump_grad_file_content(self, tmp_path): |
| d = _make_test_dumper(tmp_path, enable_grad=True) |
| x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True) |
| y = (x * 3).sum() |
|
|
| d.dump("content_check", x) |
| y.backward() |
|
|
| grad_path = _find_dump_file(tmp_path, name="grad__content_check") |
| expected_grad = torch.full((2, 2), 3.0) |
| assert torch.equal(_load_dump(grad_path)["value"], expected_grad) |
|
|
| def test_disable_value(self, tmp_path): |
| d = _make_test_dumper(tmp_path, enable_value=False, enable_grad=True) |
| x = torch.randn(3, 3, requires_grad=True) |
| y = (x * 2).sum() |
|
|
| d.dump("fwd_disabled", x) |
| y.backward() |
|
|
| filenames = _get_filenames(tmp_path) |
| assert not any( |
| "name=fwd_disabled" in f and "grad__" not in f for f in filenames |
| ) |
| _assert_files(filenames, exist=["grad__fwd_disabled"]) |
|
|
| def test_disable_grad(self, tmp_path): |
| d = _make_test_dumper(tmp_path, enable_grad=False) |
| x = torch.randn(3, 3, requires_grad=True) |
| y = (x * 2).sum() |
|
|
| d.dump("grad_disabled", x) |
| y.backward() |
|
|
| _assert_files( |
| _get_filenames(tmp_path), |
| exist=["name=grad_disabled"], |
| not_exist=["grad__"], |
| ) |
|
|
|
|
| class TestKvFilter: |
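| # Filters are Python expressions evaluated against the dump name, the |
| # extra kwargs, the global ctx, and re's search() (per the tests below). |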
| def test_format_tags(self): |
| assert _format_tags({"a": 1, "b": "hello"}) == "a=1___b=hello" |
| assert _format_tags({}) == "" |
|
|
| def test_filter_matches_extra_kwargs(self, tmp_path): |
| d = _make_test_dumper(tmp_path, filter="layer_id == 0") |
| d.dump("tensor_a", torch.randn(3), layer_id=0) |
| d.dump("tensor_b", torch.randn(3), layer_id=1) |
|
|
| filenames = _get_filenames(tmp_path) |
| _assert_files(filenames, exist=["tensor_a"], not_exist=["tensor_b"]) |
|
|
| def test_filter_matches_global_ctx(self, tmp_path): |
| d = _make_test_dumper(tmp_path, filter="ctx_arg == 200") |
| d.set_ctx(ctx_arg=200) |
| d.dump("tensor_a", torch.randn(3)) |
| d.set_ctx(ctx_arg=None) |
| d.dump("tensor_b", torch.randn(3)) |
|
|
| filenames = _get_filenames(tmp_path) |
| _assert_files(filenames, exist=["tensor_a"], not_exist=["tensor_b"]) |
|
|
| def test_filter_matches_name(self, tmp_path): |
| d = _make_test_dumper(tmp_path, filter="'keep' in name") |
| d.dump("keep_this", torch.randn(3)) |
| d.dump("skip_this", torch.randn(3)) |
|
|
| filenames = _get_filenames(tmp_path) |
| _assert_files(filenames, exist=["keep_this"], not_exist=["skip_this"]) |
|
|
| def test_filter_expr_range(self, tmp_path): |
| d = _make_test_dumper(tmp_path, filter="layer_id is not None and layer_id < 3") |
| d.dump("t0", torch.randn(3), layer_id=0) |
| d.dump("t1", torch.randn(3), layer_id=1) |
| d.dump("t5", torch.randn(3), layer_id=5) |
|
|
| filenames = _get_filenames(tmp_path) |
| _assert_files(filenames, exist=["name=t0", "name=t1"], not_exist=["name=t5"]) |
|
|
| def test_filter_expr_with_none(self, tmp_path): |
| d = _make_test_dumper(tmp_path, filter="layer_id is None or layer_id < 3") |
| d.dump("no_layer", torch.randn(3)) |
| d.dump("layer0", torch.randn(3), layer_id=0) |
| d.dump("layer5", torch.randn(3), layer_id=5) |
|
|
| filenames = _get_filenames(tmp_path) |
| _assert_files( |
| filenames, |
| exist=["no_layer", "layer0"], |
| not_exist=["layer5"], |
| ) |
|
|
| def test_filter_expr_with_re_search(self, tmp_path): |
| d = _make_test_dumper(tmp_path, filter="search(r'attn|mlp', name)") |
| d.dump("self_attn", torch.randn(3)) |
| d.dump("mlp_proj", torch.randn(3)) |
| d.dump("layernorm", torch.randn(3)) |
|
|
| filenames = _get_filenames(tmp_path) |
| _assert_files( |
| filenames, |
| exist=["self_attn", "mlp_proj"], |
| not_exist=["layernorm"], |
| ) |
|
|
| def test_filter_expr_syntax_error(self, tmp_path): |
| d = _make_test_dumper(tmp_path, filter="layer_id ===") |
| with pytest.raises(SyntaxError): |
| d.dump("tensor", torch.randn(3)) |
|
|
| def test_no_filter_dumps_all(self, tmp_path): |
| d = _make_test_dumper(tmp_path) |
| d.dump("a", torch.randn(3)) |
| d.dump("b", torch.randn(3)) |
|
|
| filenames = _get_filenames(tmp_path) |
| _assert_files(filenames, exist=["name=a", "name=b"]) |
|
|
|
|
| class TestDumpModel: |
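| # File naming, inferred from the assertions: parameter values dump as |
| # "<prefix>__<param>" and gradients as "grad__<prefix>__<param>". |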
| def test_grad_basic(self, tmp_path): |
| d = _make_test_dumper( |
| tmp_path, enable_model_grad=True, enable_model_value=False |
| ) |
| model = torch.nn.Linear(4, 2) |
| x = torch.randn(3, 4) |
| y = model(x).sum() |
| y.backward() |
|
|
| d.dump_model(model, name_prefix="model") |
|
|
| _assert_files( |
| _get_filenames(tmp_path), |
| exist=["grad__model__weight", "grad__model__bias"], |
| ) |
|
|
| def test_value_basic(self, tmp_path): |
| d = _make_test_dumper( |
| tmp_path, enable_model_value=True, enable_model_grad=False |
| ) |
| model = torch.nn.Linear(4, 2, bias=False) |
|
|
| d.dump_model(model, name_prefix="model") |
|
|
| _assert_files( |
| _get_filenames(tmp_path), |
| exist=["model__weight"], |
| ) |
|
|
| def test_no_grad_skipped(self, tmp_path): |
| d = _make_test_dumper( |
| tmp_path, enable_model_grad=True, enable_model_value=False |
| ) |
| model = torch.nn.Linear(4, 2) |
|
|
| d.dump_model(model, name_prefix="model") |
|
|
| filenames = _get_filenames(tmp_path) |
| assert len(filenames) == 0 |
|
|
| def test_filter(self, tmp_path): |
| d = _make_test_dumper( |
| tmp_path, |
| enable_model_value=True, |
| enable_model_grad=True, |
| filter="'weight' in name", |
| ) |
| model = torch.nn.Linear(4, 2) |
| x = torch.randn(3, 4) |
| y = model(x).sum() |
| y.backward() |
|
|
| d.dump_model(model, name_prefix="model") |
|
|
| _assert_files( |
| _get_filenames(tmp_path), |
| exist=["model__weight", "grad__model__weight"], |
| not_exist=["model__bias", "grad__model__bias"], |
| ) |
|
|
| def test_grad_file_content(self, tmp_path): |
| d = _make_test_dumper( |
| tmp_path, enable_model_grad=True, enable_model_value=False |
| ) |
| model = torch.nn.Linear(4, 2, bias=False) |
| x = torch.ones(1, 4) |
| y = model(x).sum() |
| y.backward() |
|
|
| d.dump_model(model, name_prefix="p") |
|
|
| path = _find_dump_file(tmp_path, name="grad__p__weight") |
| assert torch.equal(_load_dump(path)["value"], model.weight.grad) |
|
|
| def test_disable_model_grad(self, tmp_path): |
| d = _make_test_dumper( |
| tmp_path, enable_model_value=True, enable_model_grad=False |
| ) |
| model = torch.nn.Linear(4, 2) |
| x = torch.randn(3, 4) |
| y = model(x).sum() |
| y.backward() |
|
|
| d.dump_model(model, name_prefix="model") |
|
|
| filenames = _get_filenames(tmp_path) |
| assert all("grad" not in f for f in filenames) |
|
|
| def test_parameter_saved_as_parameter(self, tmp_path): |
| d = _make_test_dumper( |
| tmp_path, enable_model_value=True, enable_model_grad=False |
| ) |
| model = torch.nn.Linear(4, 2, bias=False) |
|
|
| d.dump_model(model, name_prefix="p") |
|
|
| path = _find_dump_file(tmp_path, name="p__weight") |
| loaded = _load_dump(path) |
| assert isinstance(loaded["value"], torch.nn.Parameter) |
| assert torch.equal(loaded["value"], model.weight) |
|
|
| def test_unpicklable_parameter_falls_back_to_data(self, tmp_path): |
| class BadParam(torch.nn.Parameter): |
| def __reduce_ex__(self, protocol): |
| raise RuntimeError("not pickleable") |
|
|
| d = _make_test_dumper( |
| tmp_path, enable_model_value=True, enable_model_grad=False |
| ) |
| model = torch.nn.Linear(4, 2, bias=False) |
| model.weight = BadParam(model.weight.data) |
|
|
| d.dump_model(model, name_prefix="p") |
|
|
| path = _find_dump_file(tmp_path, name="p__weight") |
| loaded = _load_dump(path) |
| assert isinstance(loaded["value"], torch.Tensor) |
| assert not isinstance(loaded["value"], torch.nn.Parameter) |
| assert torch.equal(loaded["value"], model.weight.data) |
|
|
| def test_disable_model_value(self, tmp_path): |
| d = _make_test_dumper( |
| tmp_path, enable_model_grad=True, enable_model_value=False |
| ) |
| model = torch.nn.Linear(4, 2, bias=False) |
| x = torch.ones(1, 4) |
| y = model(x).sum() |
| y.backward() |
|
|
| d.dump_model(model, name_prefix="model") |
|
|
| filenames = _get_filenames(tmp_path) |
| assert all("grad" in f for f in filenames) |
|
|
|
|
| class TestCleanup: |
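| # cleanup_previous removes stale "dump_*" directories and the configured |
| # exp_name directory under dir before the first write (per the tests). |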
| def test_cleanup_removes_old_dumps(self, tmp_path): |
| old_dir = tmp_path / "dump_old" |
| old_dir.mkdir() |
| (old_dir / "dummy.pt").touch() |
|
|
| dumper = _make_test_dumper(tmp_path, cleanup_previous=True) |
| dumper.dump("new_tensor", torch.randn(3, 3)) |
|
|
| assert not old_dir.exists() |
| _assert_files(_get_filenames(tmp_path), exist=["new_tensor"]) |
|
|
| def test_cleanup_removes_exp_name_dir(self, tmp_path): |
| exp_name = "my_custom_exp" |
| old_exp_dir = tmp_path / exp_name |
| old_exp_dir.mkdir() |
| (old_exp_dir / "old_data.pt").touch() |
|
|
| dumper = _make_test_dumper(tmp_path, exp_name=exp_name, cleanup_previous=True) |
| dumper.dump("new_tensor", torch.randn(3, 3)) |
|
|
| assert not (tmp_path / exp_name / "old_data.pt").exists() |
| _assert_files(_get_filenames(tmp_path), exist=["new_tensor"]) |
|
|
| def test_cleanup_removes_both_dump_prefix_and_exp_name(self, tmp_path): |
| old_dump = tmp_path / "dump_old" |
| old_dump.mkdir() |
| (old_dump / "dummy.pt").touch() |
|
|
| exp_name = "custom_run" |
| old_exp = tmp_path / exp_name |
| old_exp.mkdir() |
| (old_exp / "stale.pt").touch() |
|
|
| dumper = _make_test_dumper(tmp_path, exp_name=exp_name, cleanup_previous=True) |
| dumper.dump("new_tensor", torch.randn(3, 3)) |
|
|
| assert not old_dump.exists() |
| assert not (tmp_path / exp_name / "stale.pt").exists() |
| _assert_files(_get_filenames(tmp_path), exist=["new_tensor"]) |
|
|
| def test_no_cleanup_by_default(self, tmp_path): |
| old_dir = tmp_path / "dump_old" |
| old_dir.mkdir() |
| (old_dir / "dummy.pt").touch() |
|
|
| dumper = _make_test_dumper(tmp_path) |
| dumper.dump("new_tensor", torch.randn(3, 3)) |
|
|
| assert old_dir.exists() |
| _assert_files(_get_filenames(tmp_path), exist=["new_tensor"]) |
|
|
|
|
| class TestReset: |
| def test_reset_clears_state(self, tmp_path): |
| d = _make_test_dumper(tmp_path) |
| d.set_ctx(layer_id=1) |
| d.dump("before_reset", torch.randn(3, 3)) |
|
|
| d.reset() |
|
|
| assert d._state.dump_index == 0 |
| assert d._state.step == 0 |
| assert d._state.global_ctx == {} |
|
|
| def test_dump_works_after_reset(self, tmp_path): |
| d = _make_test_dumper(tmp_path) |
| d.dump("pre", torch.randn(3, 3)) |
|
|
| d.reset() |
| d.dump("post", torch.randn(3, 3)) |
|
|
| filenames = _get_filenames(tmp_path) |
| _assert_files(filenames, exist=["pre", "post"]) |
| post_file = _find_dump_file(tmp_path, name="post") |
| assert "dump_index=1" in post_file.name |
|
|
| def test_cleanup_previous_re_triggers_after_reset(self, tmp_path): |
| """Miles pattern: reset() + configure(cleanup_previous=True) should re-clean.""" |
| exp_alpha = "exp_alpha" |
| exp_beta = "exp_beta" |
|
|
| (tmp_path / exp_alpha).mkdir() |
| (tmp_path / exp_alpha / "stale.pt").touch() |
| (tmp_path / exp_beta).mkdir() |
| (tmp_path / exp_beta / "stale.pt").touch() |
|
|
| d = _make_test_dumper(tmp_path, exp_name=exp_alpha, cleanup_previous=True) |
| d.dump("phase1", torch.randn(2, 2)) |
|
|
| d.reset() |
| d.configure(exp_name=exp_beta, cleanup_previous=True) |
| d.dump("phase2", torch.randn(2, 2)) |
|
|
| assert not (tmp_path / exp_alpha / "stale.pt").exists() |
| assert not (tmp_path / exp_beta / "stale.pt").exists() |
| filenames = _get_filenames(tmp_path) |
| _assert_files(filenames, exist=["phase1", "phase2"]) |
|
|
| def test_no_cleanup_when_config_false(self, tmp_path): |
| """cleanup_previous=False: handled stays False but no cleanup runs.""" |
| old_dir = tmp_path / "dump_old" |
| old_dir.mkdir() |
| (old_dir / "dummy.pt").touch() |
|
|
| d = _make_test_dumper(tmp_path, cleanup_previous=False) |
| d.dump("tensor", torch.randn(2, 2)) |
|
|
| assert old_dir.exists() |
| assert d._state.cleanup_previous_handled is False |
|
|
| def test_multi_phase_switch(self, tmp_path): |
| """Simulate Miles multi-phase: configure → dump → reset → configure new phase → dump.""" |
| d = _make_test_dumper(tmp_path, cleanup_previous=True) |
|
|
| d.configure(exp_name="fwd_only") |
| d.dump("weight", torch.randn(2, 2)) |
| d.step() |
| d.configure(enable=False) |
|
|
| d.reset() |
| d.configure(exp_name="fwd_bwd", enable=True, cleanup_previous=True) |
| d.dump("weight", torch.randn(2, 2)) |
| d.step() |
|
|
| fwd_only_files = list(Path(tmp_path).glob("fwd_only/*.pt")) |
| fwd_bwd_files = list(Path(tmp_path).glob("fwd_bwd/*.pt")) |
| assert len(fwd_only_files) > 0 |
| assert len(fwd_bwd_files) > 0 |
| assert d._state.step == 1 |
| assert d._state.dump_index == 1 |
|
|
| def test_reset_removes_non_intrusive_hooks(self, tmp_path): |
| model = torch.nn.Sequential( |
| torch.nn.Linear(4, 4), |
| torch.nn.ReLU(), |
| torch.nn.Linear(4, 4), |
| ) |
| d = _make_test_dumper(tmp_path, non_intrusive_mode="all") |
| d.register_non_intrusive_dumper(model) |
|
|
| x = torch.randn(2, 4) |
| with d.capture_output() as captured: |
| model(x) |
| assert len(captured) > 0 |
|
|
| d.reset() |
| d.configure(enable=True, dir=str(tmp_path), non_intrusive_mode="all") |
|
|
| with d.capture_output() as captured_after: |
| model(x) |
| assert len(captured_after) == 0 |
|
|
| def test_reset_removes_non_intrusive_hooks_multiple_models(self, tmp_path): |
| model_a = torch.nn.Sequential( |
| torch.nn.Linear(4, 4), |
| torch.nn.ReLU(), |
| ) |
| model_b = torch.nn.Sequential( |
| torch.nn.Linear(4, 4), |
| torch.nn.ReLU(), |
| ) |
| d = _make_test_dumper(tmp_path, non_intrusive_mode="all") |
| d.register_non_intrusive_dumper(model_a) |
| d.register_non_intrusive_dumper(model_b) |
|
|
| x = torch.randn(2, 4) |
| with d.capture_output() as captured: |
| model_a(x) |
| model_b(x) |
| assert len(captured) > 0 |
|
|
| d.reset() |
| d.configure(enable=True, dir=str(tmp_path), non_intrusive_mode="all") |
|
|
| with d.capture_output() as captured_a: |
| model_a(x) |
| assert len(captured_a) == 0 |
|
|
| with d.capture_output() as captured_b: |
| model_b(x) |
| assert len(captured_b) == 0 |
|
|
|
|
| def _dumper_worker(rank, http_port: int, stop_event): |
| """Minimal distributed dumper worker: configure, step (triggers ZMQ setup), then wait.""" |
| dumper.configure(enable=False, server_port=str(http_port)) |
| dumper.step() |
| stop_event.wait() |
|
|
|
|
| def _wait_for_dumper_http(url: str, timeout: float = 30) -> None: |
| deadline = time.time() + timeout |
| while time.time() < deadline: |
| try: |
| requests.post(f"{url}/dumper/configure", json={}, timeout=2) |
| return |
| except requests.ConnectionError: |
| time.sleep(0.5) |
| raise TimeoutError(f"Dumper HTTP server not reachable at {url}") |
|
|
|
|
| class TestZmqPortIsolation: |
| """Multiple independent dumper instances (each with 2 ranks) must not conflict on ZMQ ports.""" |
|
|
| NUM_INSTANCES = 3 |
|
|
| def test_concurrent_instances_no_port_conflict(self): |
| ports = [ |
| find_available_port(40000 + i * 1000) for i in range(self.NUM_INSTANCES) |
| ] |
| stop_events = [] |
| threads = [] |
| ctx = multiprocessing.get_context("spawn") |
|
|
| for port in ports: |
| stop_event = ctx.Event() |
| stop_events.append(stop_event) |
| thread = threading.Thread( |
| target=run_distributed_test, |
| args=(_dumper_worker,), |
| kwargs={"http_port": port, "stop_event": stop_event}, |
| ) |
| thread.start() |
| threads.append(thread) |
|
|
| try: |
| for port in ports: |
| _wait_for_dumper_http(f"http://127.0.0.1:{port}") |
|
|
| for i, port in enumerate(ports): |
| resp = requests.post( |
| f"http://127.0.0.1:{port}/dumper/get_state", json={} |
| ) |
| resp.raise_for_status() |
| states = resp.json() |
| assert ( |
| len(states) == 2 |
| ), f"Instance {i} (port {port}): expected 2 ranks, got {len(states)}" |
| finally: |
| for event in stop_events: |
| event.set() |
| for thread in threads: |
| thread.join(timeout=10) |
|
|
|
|
| class TestDumperHttp: |
| """Test /dumper/* HTTP control — parametrized over standalone vs sglang server.""" |
|
|
| @pytest.fixture(scope="class", params=["standalone", "sglang"]) |
| def dumper_http_url(self, request): |
| if request.param == "standalone": |
| http_port = find_available_port(40000) |
| base_url = f"http://127.0.0.1:{http_port}" |
| stop_event = multiprocessing.get_context("spawn").Event() |
| thread = threading.Thread( |
| target=run_distributed_test, |
| args=(_dumper_worker,), |
| kwargs={"http_port": http_port, "stop_event": stop_event}, |
| ) |
| thread.start() |
| try: |
| _wait_for_dumper_http(base_url) |
| yield base_url |
| finally: |
| stop_event.set() |
| thread.join(timeout=10) |
| else: |
| base_url = DEFAULT_URL_FOR_TEST |
| env = {**os.environ, "DUMPER_SERVER_PORT": "reuse"} |
| proc = popen_launch_server( |
| "Qwen/Qwen3-0.6B", |
| base_url, |
| timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, |
| other_args=["--max-total-tokens", "128"], |
| env=env, |
| ) |
| try: |
| yield base_url |
| finally: |
| kill_process_tree(proc.pid) |
|
|
| @staticmethod |
| def _post(base_url: str, method: str, **kwargs) -> list[dict]: |
| resp = requests.post(f"{base_url}/dumper/{method}", json=kwargs or None) |
| resp.raise_for_status() |
| states = resp.json() |
| assert isinstance(states, list) and len(states) >= 1 |
| return states |
|
|
| @staticmethod |
| def _assert_all_ranks(states: list[dict], path: str, expected): |
| """Assert that ``state[path]`` equals ``expected`` on every rank.""" |
| keys = path.split(".") |
| for rank, state in enumerate(states): |
| val = state |
| for k in keys: |
| val = val[k] |
| assert ( |
| val == expected |
| ), f"rank {rank}: {path}={val!r}, expected {expected!r}" |
|
|
| def test_configure_enable_toggle(self, dumper_http_url: str): |
| for enable in [True, False]: |
| self._post(dumper_http_url, "configure", enable=enable) |
| states = self._post(dumper_http_url, "get_state") |
| self._assert_all_ranks(states, "config.enable", enable) |
|
|
| def test_configure_multi_field(self, dumper_http_url: str): |
| self._post( |
| dumper_http_url, |
| "configure", |
| enable=True, |
| filter="layer_id == 0", |
| dir="/tmp/test_http", |
| ) |
| states = self._post(dumper_http_url, "get_state") |
| self._assert_all_ranks(states, "config.enable", True) |
| self._assert_all_ranks(states, "config.filter", "layer_id == 0") |
| self._assert_all_ranks(states, "config.dir", "/tmp/test_http") |
|
|
| def test_configure_clear_optional(self, dumper_http_url: str): |
| self._post(dumper_http_url, "configure", filter="layer_id == 0") |
| self._post(dumper_http_url, "configure", filter=None) |
| states = self._post(dumper_http_url, "get_state") |
| self._assert_all_ranks(states, "config.filter", None) |
|
|
| def test_reset(self, dumper_http_url: str): |
| self._post(dumper_http_url, "configure", enable=True) |
| self._post(dumper_http_url, "reset") |
| states = self._post(dumper_http_url, "get_state") |
| self._assert_all_ranks(states, "dump_index", 0) |
| self._assert_all_ranks(states, "step", 0) |
|
|
| def test_get_state(self, dumper_http_url: str): |
| self._post( |
| dumper_http_url, |
| "configure", |
| enable=True, |
| filter="layer_id is not None and layer_id < 3", |
| ) |
| states = self._post(dumper_http_url, "get_state") |
| self._assert_all_ranks(states, "config.enable", True) |
| self._assert_all_ranks( |
| states, "config.filter", "layer_id is not None and layer_id < 3" |
| ) |
| for state in states: |
| assert "dump_index" in state |
| assert "step" in state |
|
|
| def test_all_ranks_consistent(self, dumper_http_url: str): |
| self._post(dumper_http_url, "configure", enable=True, dir="/tmp/multi") |
| states = self._post(dumper_http_url, "get_state") |
| configs = [s["config"] for s in states] |
| for rank_config in configs[1:]: |
| assert rank_config == configs[0], f"rank configs diverged: {configs}" |
|
|
| def test_error_unknown_field(self, dumper_http_url: str): |
| resp = requests.post( |
| f"{dumper_http_url}/dumper/configure", |
| json={"nonexistent_field": 123}, |
| ) |
| assert resp.status_code == 400 |
|
|
| def test_error_unknown_method(self, dumper_http_url: str): |
| resp = requests.post( |
| f"{dumper_http_url}/dumper/nonexistent", |
| json={}, |
| ) |
| assert resp.status_code == 400 |
|
|
| def test_error_wrong_type(self, dumper_http_url: str): |
| resp = requests.post( |
| f"{dumper_http_url}/dumper/configure", |
| json={"enable": "not_a_bool"}, |
| ) |
| assert resp.status_code == 400 |
|
|
|
|
| class TestRegisterForwardHookOrReplaceFn: |
| def test_unknown_mode_raises(self): |
| module = torch.nn.Linear(4, 4) |
| with pytest.raises(ValueError, match="Unknown mode"): |
| _register_forward_hook_or_replace_fn( |
| module, |
| pre_hook=lambda _mod, _input: None, |
| hook=lambda _mod, _input, _output: None, |
| mode="bad", |
| ) |
|
|
|
|
| class _NonIntrusiveTestBase: |
| _PREFIX = "non_intrusive__" |
|
|
| @staticmethod |
| def _assert_captured_contains( |
| captured: dict, expected: list[str], prefix: str = "non_intrusive__" |
| ) -> None: |
| for suffix in expected: |
| key = f"{prefix}{suffix}" |
| assert key in captured, f"missing {key}" |
|
|
| @staticmethod |
| def _wrap_as_outer(inner_cls: type) -> torch.nn.Module: |
| """Wrap an inner module class as OuterModel.model, mimicking typical model nesting.""" |
|
|
| class OuterModel(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.model = inner_cls() |
|
|
| def forward(self, *args, **kwargs): |
| return self.model(*args, **kwargs) |
|
|
| return OuterModel() |
|
|
| @staticmethod |
| def _make_dumper(tmp_path, **overrides) -> "_Dumper": |
| return _make_test_dumper(tmp_path, non_intrusive_mode="all", **overrides) |
|
|
| def _run(self, tmp_path, inner_cls, **dumper_overrides): |
| d = self._make_dumper(tmp_path, **dumper_overrides) |
| model = self._wrap_as_outer(inner_cls) |
| d.register_non_intrusive_dumper(model) |
| x = torch.randn(2, 4) |
| with d.capture_output() as captured: |
| output = model(x) |
| return captured, x, output |
|
|
|
|
| class TestNonIntrusiveDumper(_NonIntrusiveTestBase): |
| """Tests for mode='all' — hooks on every module, non_intrusive__ prefix.""" |
|
|
| def test_basic_inputs_and_outputs(self, tmp_path): |
| class Inner(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(4, 4) |
| self.relu = torch.nn.ReLU() |
|
|
| def forward(self, x): |
| return self.relu(self.linear(x)) |
|
|
| captured, x, output = self._run(tmp_path, Inner) |
|
|
| self._assert_captured_contains( |
| captured, |
| [ |
| "output", |
| "inputs.0", |
| "model.output", |
| "model.inputs.0", |
| "model.linear.output", |
| "model.linear.inputs.0", |
| "model.relu.output", |
| "model.relu.inputs.0", |
| ], |
| ) |
| P = self._PREFIX |
| assert torch.allclose(captured[f"{P}output"]["value"], output) |
|
|
| def test_inputs_dumped_before_forward(self, tmp_path): |
| """Inputs are captured *before* forward(); in-place mutation must not affect them.""" |
|
|
| class Mutator(torch.nn.Module): |
| def forward(self, x): |
| x.fill_(999.0) |
| return x |
|
|
| class Inner(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.mutator = Mutator() |
|
|
| def forward(self, x): |
| return self.mutator(x) |
|
|
| d = self._make_dumper(tmp_path) |
| model = self._wrap_as_outer(Inner) |
| d.register_non_intrusive_dumper(model) |
|
|
| x = torch.randn(2, 4) |
| original_x = x.clone() |
| with d.capture_output() as captured: |
| model(x) |
|
|
| P = self._PREFIX |
| dumped_input = captured[f"{P}model.mutator.inputs.0"]["value"] |
| assert torch.allclose(dumped_input, original_x), ( |
| f"pre-hook should capture inputs before forward mutates them; " |
| f"got {dumped_input} but expected {original_x}" |
| ) |
|
|
| dumped_output = captured[f"{P}model.mutator.output"]["value"] |
| assert ( |
| dumped_output == 999.0 |
| ).all(), "post-hook should capture outputs after forward" |
|
|
| def test_hooks_all_module_levels(self, tmp_path): |
| class Attention(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.qkv_proj = torch.nn.Linear(4, 12) |
| self.o_proj = torch.nn.Linear(4, 4) |
|
|
| def forward(self, x): |
| _qkv = self.qkv_proj(x) |
| return self.o_proj(x) |
|
|
| class Layer(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.self_attn = Attention() |
| self.mlp = torch.nn.Linear(4, 4) |
|
|
| def forward(self, x): |
| x = self.self_attn(x) |
| return self.mlp(x) |
|
|
| class Inner(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.layers = torch.nn.ModuleList([Layer()]) |
|
|
| def forward(self, x): |
| for layer in self.layers: |
| x = layer(x) |
| return x |
|
|
| captured, x, output = self._run(tmp_path, Inner) |
|
|
| self._assert_captured_contains( |
| captured, |
| [ |
| "output", |
| "model.output", |
| "model.layers.0.output", |
| "model.layers.0.self_attn.output", |
| "model.layers.0.self_attn.qkv_proj.output", |
| "model.layers.0.self_attn.o_proj.output", |
| "model.layers.0.mlp.output", |
| "model.layers.0.self_attn.qkv_proj.inputs.0", |
| "model.layers.0.self_attn.o_proj.inputs.0", |
| "model.layers.0.mlp.inputs.0", |
| ], |
| ) |
| P = self._PREFIX |
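| # ModuleList is never called directly, so it produces no hook entries |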
| assert f"{P}model.layers.output" not in captured |
|
|
| def test_multi_tensor_tuple_output(self, tmp_path): |
| class TupleModule(torch.nn.Module): |
| def forward(self, x): |
| return x, x * 2 |
|
|
| class Inner(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.split = TupleModule() |
| self.linear = torch.nn.Linear(4, 4) |
|
|
| def forward(self, x): |
| a, b = self.split(x) |
| return self.linear(a + b) |
|
|
| captured, x, output = self._run(tmp_path, Inner) |
|
|
| assert "non_intrusive__model.split.output.0" in captured |
| assert "non_intrusive__model.split.output.1" in captured |
| assert torch.allclose( |
| captured["non_intrusive__model.split.output.0"]["value"], x |
| ) |
|
|
| def test_single_tensor_tuple_collapses(self, tmp_path): |
| class SingleTupleModule(torch.nn.Module): |
| def forward(self, x): |
| return (x * 3,) |
|
|
| class Inner(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.wrap = SingleTupleModule() |
|
|
| def forward(self, x): |
| return self.wrap(x)[0] |
|
|
| captured, x, output = self._run(tmp_path, Inner) |
|
|
| assert "non_intrusive__model.wrap.output" in captured |
| assert "non_intrusive__model.wrap.output.0" not in captured |
|
|
| def test_multiple_forward_inputs(self, tmp_path): |
| class TwoInputModule(torch.nn.Module): |
| def forward(self, x, mask): |
| return x * mask |
|
|
| class Inner(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.mul = TwoInputModule() |
|
|
| def forward(self, x): |
| mask = torch.ones_like(x) |
| return self.mul(x, mask) |
|
|
| captured, x, output = self._run(tmp_path, Inner) |
|
|
| assert "non_intrusive__model.mul.inputs.0" in captured |
| assert "non_intrusive__model.mul.inputs.1" in captured |
|
|
| def test_none_output_only_dumps_inputs(self, tmp_path): |
| class NoneModule(torch.nn.Module): |
| def forward(self, x): |
| return None |
|
|
| class Inner(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.sink = NoneModule() |
|
|
| def forward(self, x): |
| self.sink(x) |
| return x |
|
|
| captured, x, output = self._run(tmp_path, Inner) |
|
|
| assert "non_intrusive__model.sink.inputs.0" in captured |
| assert not any( |
| k.startswith("non_intrusive__model.sink.output") for k in captured |
| ) |
|
|
| def test_non_tensor_value_silently_skipped(self, tmp_path): |
| class IntModule(torch.nn.Module): |
| def forward(self, x): |
| return 42 |
|
|
| class Inner(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.const = IntModule() |
|
|
| def forward(self, x): |
| self.const(x) |
| return x |
|
|
| captured, x, output = self._run(tmp_path, Inner) |
|
|
| assert "non_intrusive__model.const.inputs.0" in captured |
| assert not any( |
| k.startswith("non_intrusive__model.const.output") for k in captured |
| ) |
|
|
| def test_root_module_name_no_malformed_dots(self, tmp_path): |
| d = self._make_dumper(tmp_path) |
| model = torch.nn.Linear(4, 4) |
| d.register_non_intrusive_dumper(model) |
|
|
| x = torch.randn(2, 4) |
| with d.capture_output() as captured: |
| model(x) |
|
|
| for key in captured: |
| assert not key.startswith("non_intrusive__."), f"malformed key: {key}" |
| assert ".." not in key, f"double dot in key: {key}" |
|
|
| assert "non_intrusive__output" in captured |
| assert "non_intrusive__inputs.0" in captured |
|
|
| def test_respects_dumper_filter(self, tmp_path): |
| class Inner(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(4, 4) |
| self.relu = torch.nn.ReLU() |
|
|
| def forward(self, x): |
| return self.relu(self.linear(x)) |
|
|
| captured, x, output = self._run( |
| tmp_path, Inner, filter="name == 'non_intrusive__model.linear.output'" |
| ) |
|
|
| assert "non_intrusive__model.linear.output" in captured |
| assert "non_intrusive__model.relu.output" not in captured |
| assert "non_intrusive__model.linear.inputs.0" not in captured |
|
|
| def test_disabled_dumper_no_output(self, tmp_path): |
| class Inner(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(4, 4) |
|
|
| def forward(self, x): |
| return self.linear(x) |
|
|
| d = self._make_dumper(tmp_path) |
| d.configure(enable=False) |
| model = self._wrap_as_outer(Inner) |
| d.register_non_intrusive_dumper(model) |
|
|
| x = torch.randn(2, 4) |
| with d.capture_output() as captured: |
| model(x) |
|
|
| assert len(captured) == 0 |
|
|
|
|
| def _make_forward_batch(): |
| from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode |
|
|
| return ForwardBatch( |
| forward_mode=ForwardMode.DECODE, |
| batch_size=2, |
| input_ids=torch.tensor([10, 20]), |
| req_pool_indices=torch.zeros(2, dtype=torch.long), |
| seq_lens=torch.tensor([5, 6]), |
| out_cache_loc=torch.zeros(2, dtype=torch.long), |
| seq_lens_sum=11, |
| positions=torch.tensor([0, 1]), |
| ) |
|
|
|
|
| class TestNonIntrusiveDumperConfigMode(_NonIntrusiveTestBase): |
| @staticmethod |
| def _build_model() -> torch.nn.Module: |
| class SubLayer(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(4, 4) |
|
|
| def forward(self, forward_batch): |
| return self.linear( |
| forward_batch.input_ids.float().unsqueeze(-1).expand(-1, 4) |
| ) |
|
|
| class Root(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.layer = SubLayer() |
|
|
| def forward(self, forward_batch): |
| return self.layer(forward_batch) |
|
|
| return Root() |
|
|
| def _run(self, tmp_path, mode: str) -> tuple: |
| d = _make_test_dumper(tmp_path, non_intrusive_mode=mode) |
| model = self._build_model() |
| d.register_non_intrusive_dumper(model) |
| forward_batch = _make_forward_batch() |
| with d.capture_output() as captured: |
| model(forward_batch) |
| return captured, forward_batch |
|
|
| def test_off_mode(self, tmp_path): |
| captured, _ = self._run(tmp_path, "off") |
| assert len(captured) == 0 |
|
|
| def test_core_mode(self, tmp_path): |
| captured, fb = self._run(tmp_path, "core") |
|
|
| # Core ForwardBatch fields are dumped under their bare names |
| assert "input_ids" in captured |
| assert "positions" in captured |
| assert "seq_lens" in captured |
| assert torch.equal(captured["input_ids"]["value"], fb.input_ids) |
| assert torch.equal(captured["positions"]["value"], fb.positions) |
| assert torch.equal(captured["seq_lens"]["value"], fb.seq_lens) |
|
|
| # Core mode emits no per-module non_intrusive__ dumps |
| assert not any(k.startswith("non_intrusive__") for k in captured) |
|
|
| def test_all_mode(self, tmp_path): |
| captured, fb = self._run(tmp_path, "all") |
|
|
| # Core fields are still dumped under their bare names in "all" mode |
| assert "input_ids" in captured |
| assert "positions" in captured |
| assert "seq_lens" in captured |
| assert torch.equal(captured["input_ids"]["value"], fb.input_ids) |
| assert torch.equal(captured["positions"]["value"], fb.positions) |
| assert torch.equal(captured["seq_lens"]["value"], fb.seq_lens) |
|
|
| # Core fields are not duplicated under prefixed non_intrusive__ keys |
| for field in ("input_ids", "positions", "seq_lens"): |
| assert not any( |
| k.startswith("non_intrusive__") and k.endswith(field) for k in captured |
| ) |
|
|
| # The ForwardBatch argument itself is skipped on sub-module inputs |
| assert not any( |
| k.startswith("non_intrusive__layer.inputs.") and "seq_lens" in k |
| for k in captured |
| ), f"ForwardBatch skipped on sub-module, got: {list(captured.keys())}" |
|
|
| # Ordinary tensor outputs are still dumped for every sub-module |
| assert "non_intrusive__layer.linear.output" in captured |
| assert "non_intrusive__layer.output" in captured |
|
|
|
|
| class _LayerWithNumber(torch.nn.Module): |
| """Test helper: module with a ``layer_number`` attribute (Megatron style).""" |
|
|
| def __init__(self, layer_number: int): |
| super().__init__() |
| self.layer_number = layer_number |
| self.linear = torch.nn.Linear(4, 4) |
|
|
| def forward(self, x): |
| return self.linear(x) |
|
|
|
|
| class TestNonIntrusiveLayerIdCtx(_NonIntrusiveTestBase): |
| """Tests for automatic layer_id context injection via set_ctx.""" |
|
|
| def test_layer_id_from_layer_number(self, tmp_path): |
| """Megatron PP: layer_number (1-based global) -> layer_id = layer_number - 1.""" |
|
|
| class Inner(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.layers = torch.nn.ModuleList( |
| [_LayerWithNumber(10), _LayerWithNumber(11)] |
| ) |
|
|
| def forward(self, x): |
| for layer in self.layers: |
| x = layer(x) |
| return x |
|
|
| captured, x, output = self._run(tmp_path, Inner) |
|
|
| layer0_key = "non_intrusive__model.layers.0.linear.output" |
| layer1_key = "non_intrusive__model.layers.1.linear.output" |
| assert layer0_key in captured |
| assert layer1_key in captured |
| assert captured[layer0_key]["meta"]["layer_id"] == 9 |
| assert captured[layer1_key]["meta"]["layer_id"] == 10 |
|
|
| root_key = "non_intrusive__output" |
| assert root_key in captured |
| assert "layer_id" not in captured[root_key]["meta"] |
|
|
| def test_layer_id_from_layer_id_attr(self, tmp_path): |
| """SGLang style: module has layer_id attribute directly.""" |
|
|
| class Layer(torch.nn.Module): |
| def __init__(self, layer_id: int): |
| super().__init__() |
| self.layer_id = layer_id |
| self.linear = torch.nn.Linear(4, 4) |
|
|
| def forward(self, x): |
| return self.linear(x) |
|
|
| class Inner(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.layers = torch.nn.ModuleList([Layer(5)]) |
|
|
| def forward(self, x): |
| for layer in self.layers: |
| x = layer(x) |
| return x |
|
|
| captured, x, output = self._run(tmp_path, Inner) |
|
|
| layer_key = "non_intrusive__model.layers.0.linear.output" |
| assert layer_key in captured |
| assert captured[layer_key]["meta"]["layer_id"] == 5 |
|
|
| def test_layer_id_fallback_from_module_name(self, tmp_path): |
| """layers.N modules without layer_number/layer_id -> layer_id from module name.""" |
|
|
| class Inner(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.layers = torch.nn.ModuleList( |
| [torch.nn.Linear(4, 4), torch.nn.Linear(4, 4)] |
| ) |
|
|
| def forward(self, x): |
| for layer in self.layers: |
| x = layer(x) |
| return x |
|
|
| captured, x, output = self._run(tmp_path, Inner) |
|
|
| assert len(captured) > 0 |
| input_keys: list[str] = [ |
| k for k in captured if "model.layers." in k and "inputs" in k |
| ] |
| assert len(input_keys) > 0 |
| for key in input_keys: |
| meta = captured[key]["meta"] |
| assert "layer_id" in meta, f"{key} missing layer_id" |
| if "layers.0" in key: |
| assert meta["layer_id"] == 0 |
| elif "layers.1" in key: |
| assert meta["layer_id"] == 1 |
|
|
| def test_filter_by_layer_id(self, tmp_path): |
| """filter='layer_id == 0' keeps only layer 0 dumps.""" |
|
|
| class Inner(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.layers = torch.nn.ModuleList( |
| [_LayerWithNumber(1), _LayerWithNumber(2)] |
| ) |
|
|
| def forward(self, x): |
| for layer in self.layers: |
| x = layer(x) |
| return x |
|
|
| captured, x, output = self._run(tmp_path, Inner, filter="layer_id == 0") |
|
|
| layer0_keys = [k for k in captured if "layers.0" in k] |
| layer1_keys = [k for k in captured if "layers.1" in k] |
| assert len(layer0_keys) > 0, "layer 0 dumps should be kept" |
| assert len(layer1_keys) == 0, f"layer 1 dumps should be filtered: {layer1_keys}" |
|
|
|
|
| class TestDumperE2E: |
| def test_step_and_non_intrusive_hooks(self, tmp_path): |
| base_url = DEFAULT_URL_FOR_TEST |
| dump_dir = str(tmp_path) |
| env = { |
| **os.environ, |
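| # "reuse" is expected to expose the /dumper/* control endpoints on the |
| # main server port, so the requests below target base_url directly |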
| "DUMPER_SERVER_PORT": "reuse", |
| } |
| proc = popen_launch_server( |
| "Qwen/Qwen3-0.6B", |
| base_url, |
| timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, |
| other_args=["--tp", "2", "--max-total-tokens", "128"], |
| env=env, |
| ) |
| try: |
| states = requests.post(f"{base_url}/dumper/get_state", json={}).json() |
| assert len(states) == 2, f"Expected 2 ranks (tp=2), got {len(states)}" |
| for state in states: |
| assert state["config"]["enable"] is False |
| assert state["step"] == 0 |
|
|
| requests.post( |
| f"{base_url}/dumper/configure", |
| json={"enable": True, "dir": dump_dir}, |
| ).raise_for_status() |
|
|
| states = requests.post(f"{base_url}/dumper/get_state", json={}).json() |
| assert len(states) == 2 |
| for rank, state in enumerate(states): |
| assert ( |
| state["config"]["enable"] is True |
| ), f"rank {rank}: enable should be True after configure" |
| assert state["config"]["dir"] == dump_dir |
|
|
| resp = requests.post( |
| f"{base_url}/generate", |
| json={"text": "Hello", "sampling_params": {"max_new_tokens": 8}}, |
| ) |
| assert resp.status_code == 200, f"Generate failed: {resp.text}" |
|
|
| states = requests.post(f"{base_url}/dumper/get_state", json={}).json() |
| assert len(states) == 2 |
| steps = [s["step"] for s in states] |
| for rank, step in enumerate(steps): |
| assert step > 0, f"rank {rank}: step should be > 0, got {step}" |
| assert steps[0] == steps[1], f"step mismatch across ranks: {steps}" |
|
|
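| # Dump files land under <dump_dir>/dump_*/ with key=value parts |
| # (name=..., rank=...) embedded in each filename |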
| dump_files = list(Path(dump_dir).glob("dump_*/*.pt")) |
| assert len(dump_files) > 0, f"No dump files in {dump_dir}" |
| filenames = {f.name for f in dump_files} |
|
|
| for field in ("input_ids", "positions", "rids"): |
| assert any(f"name={field}" in f for f in filenames), ( |
| f"Missing {field} dump from non-intrusive hooks, " |
| f"got: {sorted(filenames)[:10]}" |
| ) |
|
|
| for rank in range(2): |
| assert any( |
| f"rank={rank}" in f for f in filenames |
| ), f"No dump files for rank {rank}" |
|
|
| sample_file = dump_files[0] |
| loaded = torch.load(sample_file, map_location="cpu", weights_only=False) |
| assert isinstance(loaded, dict), f"Expected dict, got {type(loaded)}" |
| assert ( |
| "value" in loaded and "meta" in loaded |
| ), f"Missing value/meta keys: {loaded.keys()}" |
| assert "name" in loaded["meta"] |
| assert "rank" in loaded["meta"] |
| assert "step" in loaded["meta"] |
|
|
| par = loaded["meta"].get("sglang_parallel_info", {}) |
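| # Parallel-topology fields the SGLang plugin records with every dump |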
| expected_keys = [ |
| "tp_rank", |
| "tp_size", |
| "pp_rank", |
| "pp_size", |
| "moe_ep_rank", |
| "moe_ep_size", |
| "moe_tp_rank", |
| "moe_tp_size", |
| "moe_dp_rank", |
| "moe_dp_size", |
| "enable_dp_attention", |
| "attn_tp_rank", |
| "attn_tp_size", |
| "attn_dp_rank", |
| "attn_dp_size", |
| "local_attn_dp_rank", |
| "local_attn_dp_size", |
| "attn_cp_rank", |
| "attn_cp_size", |
| ] |
| for key in expected_keys: |
| assert ( |
| key in par |
| ), f"Missing {key} in sglang_parallel_info, got: {sorted(par)}" |
|
|
| rids_files = [f for f in dump_files if "name=rids" in f.name] |
| rids_loaded = torch.load( |
| rids_files[0], map_location="cpu", weights_only=False |
| ) |
| rids_value = rids_loaded["value"] |
| assert isinstance( |
| rids_value, list |
| ), f"rids should be a list, got {type(rids_value)}" |
| assert len(rids_value) > 0, "rids should be non-empty" |
| assert all( |
| isinstance(r, str) for r in rids_value |
| ), f"each rid should be a str, got {[type(r) for r in rids_value]}" |
| finally: |
| kill_process_tree(proc.pid) |
|
|
|
|
| class TestRegisterForwardHook: |
| @pytest.mark.parametrize("mode", ["hook", "replace_fn"]) |
| def test_handles_removable(self, mode): |
| call_log: list[str] = [] |
|
|
| def pre_hook(_module, _args, _kwargs): |
| call_log.append("pre") |
|
|
| def hook(_module, _input, _output): |
| call_log.append("post") |
|
|
| module = torch.nn.Linear(4, 4) |
| handles = _register_forward_hook_or_replace_fn( |
| module, |
| pre_hook=pre_hook, |
| hook=hook, |
| mode=mode, |
| ) |
|
|
| x = torch.randn(2, 4) |
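| # mode="hook" fires via __call__ dispatch; mode="replace_fn" wraps |
| # module.forward itself, so a direct forward() call must also trigger |
| # both hooks |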
| if mode == "hook": |
| module(x) |
| else: |
| module.forward(x) |
| assert call_log == ["pre", "post"] |
|
|
| call_log.clear() |
| for h in handles: |
| h.remove() |
|
|
| if mode == "hook": |
| module(x) |
| else: |
| module.forward(x) |
| assert call_log == [] |
|
|
| @pytest.mark.parametrize("mode", ["hook", "replace_fn"]) |
| def test_kwargs_passed_to_pre_hook(self, mode): |
| received: list[tuple] = [] |
|
|
| class KwargsModule(torch.nn.Module): |
| def forward(self, x, *, scale=1.0): |
| return x * scale |
|
|
| def pre_hook(_module, _args, _kwargs): |
| received.append((_args, _kwargs)) |
|
|
| def hook(_module, _input, _output): |
| pass |
|
|
| module = KwargsModule() |
| _register_forward_hook_or_replace_fn( |
| module, |
| pre_hook=pre_hook, |
| hook=hook, |
| mode=mode, |
| ) |
|
|
| x = torch.randn(2, 4) |
| if mode == "hook": |
| module(x, scale=2.0) |
| else: |
| module.forward(x, scale=2.0) |
|
|
| assert len(received) == 1 |
| args, kwargs = received[0] |
| assert len(args) == 1 |
| assert torch.equal(args[0], x) |
| assert kwargs == {"scale": 2.0} |
|
|
| def test_replace_fn_remove_asserts_on_rewrap(self): |
| module = torch.nn.Linear(4, 4) |
| handles = _register_forward_hook_or_replace_fn( |
| module, |
| pre_hook=lambda _m, _a, _kw: None, |
| hook=lambda _m, _i, _o: None, |
| mode="replace_fn", |
| ) |
|
|
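| # Rebinding forward after wrapping means remove() can no longer prove it |
| # owns the attribute, so it must assert instead of silently restoring |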
| module.forward = lambda *a, **kw: None |
|
|
| with pytest.raises(AssertionError): |
| handles[0].remove() |
|
|
|
|
| class TestPluginCoreFields: |
| def test_sglang_core_fields(self): |
| plugin = _SGLangPlugin() |
| assert plugin.core_fields() == frozenset( |
| {"input_ids", "positions", "seq_lens", "req_pool_indices", "rids"} |
| ) |
|
|
| def test_megatron_core_fields(self): |
| plugin = _MegatronPlugin() |
| assert plugin.core_fields() == frozenset( |
| {"input_ids", "position_ids", "cu_seqlens_q", "cu_seqlens_kv", "qkv_format"} |
| ) |
|
|
|
|
| class TestMegatronConvertValue: |
| @pytest.fixture(autouse=True) |
| def _patch_megatron(self, monkeypatch): |
| class FakePackedSeqParams: |
| def __init__(self, **kwargs): |
| for k, v in kwargs.items(): |
| setattr(self, k, v) |
|
|
| monkeypatch.setattr(_MegatronPlugin, "_available", True) |
| monkeypatch.setattr( |
| _MegatronPlugin, "PackedSeqParams", FakePackedSeqParams, raising=False |
| ) |
| self._FakePackedSeqParams = FakePackedSeqParams |
|
|
| def test_extracts_packed_seq_params(self): |
| plugin = _MegatronPlugin() |
| cu_q = torch.tensor([0, 3, 7]) |
| cu_kv = torch.tensor([0, 3, 7]) |
| value = self._FakePackedSeqParams( |
| cu_seqlens_q=cu_q, cu_seqlens_kv=cu_kv, qkv_format="thd" |
| ) |
|
|
| result = plugin.convert_value(value, skip_forward_batch=False) |
| assert set(result.keys()) == {"cu_seqlens_q", "cu_seqlens_kv", "qkv_format"} |
| assert torch.equal(result["cu_seqlens_q"], cu_q) |
| assert torch.equal(result["cu_seqlens_kv"], cu_kv) |
| assert result["qkv_format"] == "thd" |
|
|
| def test_non_packed_returns_none(self): |
| plugin = _MegatronPlugin() |
| assert plugin.convert_value(torch.randn(4), skip_forward_batch=False) is None |
| assert plugin.convert_value("hello", skip_forward_batch=False) is None |
|
|
|
|
| class TestNonIntrusiveKwargsModel(_NonIntrusiveTestBase): |
| def test_kwargs_core_fields(self, tmp_path): |
| class KwargsModel(torch.nn.Module): |
| def forward(self, *, input_ids, position_ids): |
| return input_ids + position_ids |
|
|
| model = KwargsModel() |
| d = _make_test_dumper(tmp_path, non_intrusive_mode="core") |
| d.register_non_intrusive_dumper(model) |
|
|
| ids = torch.randn(4) |
| pos = torch.randn(4) |
| with d.capture_output() as captured: |
| model(input_ids=ids, position_ids=pos) |
|
|
| assert "input_ids" in captured |
| assert "position_ids" in captured |
| assert torch.equal(captured["input_ids"]["value"], ids) |
| assert torch.equal(captured["position_ids"]["value"], pos) |
|
|
| def test_kwargs_all_mode(self, tmp_path): |
| class KwargsModel(torch.nn.Module): |
| def forward(self, *, input_ids, position_ids, custom_value): |
| return input_ids + position_ids + custom_value |
|
|
| model = KwargsModel() |
| d = _make_test_dumper(tmp_path, non_intrusive_mode="all") |
| d.register_non_intrusive_dumper(model) |
|
|
| ids = torch.randn(4) |
| pos = torch.randn(4) |
| custom = torch.randn(4) |
| with d.capture_output() as captured: |
| model(input_ids=ids, position_ids=pos, custom_value=custom) |
|
|
| assert "input_ids" in captured |
| assert "position_ids" in captured |
|
|
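| # Non-core kwargs are still dumped, under the prefixed inputs.<name> key |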
| P = self._PREFIX |
| assert f"{P}inputs.custom_value" in captured |
|
|
| def test_mixed_args_and_kwargs(self, tmp_path): |
| class MixedModel(torch.nn.Module): |
| def forward(self, x, *, input_ids): |
| return x + input_ids |
|
|
| model = MixedModel() |
| d = _make_test_dumper(tmp_path, non_intrusive_mode="all") |
| d.register_non_intrusive_dumper(model) |
|
|
| x = torch.randn(4) |
| ids = torch.randn(4) |
| with d.capture_output() as captured: |
| model(x, input_ids=ids) |
|
|
| assert "input_ids" in captured |
|
|
| P = self._PREFIX |
| assert f"{P}inputs.0" in captured |
|
|
| def test_packed_seq_params_core_fields(self, tmp_path, monkeypatch): |
| class FakePackedSeqParams: |
| def __init__(self, **kwargs): |
| for k, v in kwargs.items(): |
| setattr(self, k, v) |
|
|
| monkeypatch.setattr(_MegatronPlugin, "_available", True) |
| monkeypatch.setattr( |
| _MegatronPlugin, "PackedSeqParams", FakePackedSeqParams, raising=False |
| ) |
|
|
| class MegatronLikeModel(torch.nn.Module): |
| def forward(self, *, input_ids, packed_seq_params): |
| return input_ids |
|
|
| model = MegatronLikeModel() |
| d = _make_test_dumper(tmp_path, non_intrusive_mode="core") |
| d.register_non_intrusive_dumper(model) |
|
|
| ids = torch.randn(4) |
| cu_q = torch.tensor([0, 3, 7]) |
| cu_kv = torch.tensor([0, 3, 7]) |
| psp = FakePackedSeqParams( |
| cu_seqlens_q=cu_q, cu_seqlens_kv=cu_kv, qkv_format="thd" |
| ) |
| with d.capture_output() as captured: |
| model(input_ids=ids, packed_seq_params=psp) |
|
|
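| # The Megatron plugin unpacks PackedSeqParams into its core fields |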
| assert "input_ids" in captured |
| assert torch.equal(captured["input_ids"]["value"], ids) |
| assert "cu_seqlens_q" in captured |
| assert torch.equal(captured["cu_seqlens_q"]["value"], cu_q) |
| assert "cu_seqlens_kv" in captured |
| assert torch.equal(captured["cu_seqlens_kv"]["value"], cu_kv) |
| assert "qkv_format" in captured |
| assert captured["qkv_format"]["value"] == "thd" |
|
|
|
|
| class TestDumperDims: |
| def test_dims_in_meta_not_filename(self, tmp_path) -> None: |
| d = _make_test_dumper(tmp_path) |
| tensor = torch.randn(4, 8) |
| d.dump("hidden", tensor, dims="b h(tp)") |
| d.step() |
|
|
| exp_dir = tmp_path / d._config.exp_name |
| pt_files = list(exp_dir.glob("*.pt")) |
| assert len(pt_files) == 1 |
|
|
| assert "dims" not in pt_files[0].stem |
|
|
| data = torch.load(pt_files[0], weights_only=False) |
| assert "dims" in data["meta"] |
| assert data["meta"]["dims"] == "b h(tp)" |
|
|
| def test_dims_grad_override(self, tmp_path) -> None: |
| d = _Dumper( |
| config=DumperConfig( |
| enable=True, |
| dir=str(tmp_path), |
| enable_grad=True, |
| ) |
| ) |
|
|
| tensor = torch.randn(4, 8, requires_grad=True) |
| dumper.dump("hidden", tensor, dims="b h(tp)", dims_grad="b h(tp:partial)") |
| dumper.step() |
|
|
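| # backward() fires the registered grad hook, producing the grad__ dump |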
| tensor.backward(torch.ones_like(tensor)) |
|
|
| exp_dir = tmp_path / d._config.exp_name |
| pt_files = sorted(exp_dir.glob("*.pt")) |
| assert len(pt_files) == 2 |
|
|
| value_file = [f for f in pt_files if "grad__" not in f.stem][0] |
| grad_file = [f for f in pt_files if "grad__" in f.stem][0] |
|
|
| value_data = torch.load(value_file, weights_only=False) |
| assert value_data["meta"]["dims"] == "b h(tp)" |
| assert value_data["meta"]["dims_grad"] == "b h(tp:partial)" |
|
|
| grad_data = torch.load(grad_file, weights_only=False) |
| assert grad_data["meta"]["dims"] == "b h(tp:partial)" |
|
|
| def test_dims_grad_inherits(self, tmp_path) -> None: |
| d = _Dumper( |
| config=DumperConfig( |
| enable=True, |
| dir=str(tmp_path), |
| enable_grad=True, |
| ) |
| ) |
|
|
| tensor = torch.randn(4, 8, requires_grad=True) |
| dumper.dump("hidden", tensor, dims="b h(tp)") |
| dumper.step() |
|
|
| tensor.backward(torch.ones_like(tensor)) |
|
|
| exp_dir = tmp_path / d._config.exp_name |
| grad_file = [f for f in exp_dir.glob("*.pt") if "grad__" in f.stem][0] |
| grad_data = torch.load(grad_file, weights_only=False) |
| assert grad_data["meta"]["dims"] == "b h(tp)" |
|
|
|
|
| class TestCtxDecorator: |
| def test_ctx_dynamic_lambda(self, tmp_path: Path) -> None: |
| d = _make_test_dumper(tmp_path) |
|
|
| class FakeLayer: |
| def __init__(self, layer_id: int) -> None: |
| self.layer_id = layer_id |
|
|
| @d.ctx(lambda self: dict(layer_id=self.layer_id)) |
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| d.dump("hidden", x) |
| return x |
|
|
| layer = FakeLayer(layer_id=42) |
| layer.forward(torch.randn(3)) |
|
|
| filenames = _get_filenames(tmp_path) |
| _assert_files(filenames, exist=["layer_id=42"]) |
|
|
| def test_ctx_static_kwargs(self, tmp_path: Path) -> None: |
| d = _make_test_dumper(tmp_path) |
|
|
| @d.ctx(phase="decode") |
| def decode_step(x: torch.Tensor) -> torch.Tensor: |
| d.dump("step_out", x) |
| return x |
|
|
| decode_step(torch.randn(3)) |
|
|
| filenames = _get_filenames(tmp_path) |
| _assert_files(filenames, exist=["phase=decode"]) |
|
|
| def test_ctx_clears_on_exception(self, tmp_path: Path) -> None: |
| d = _make_test_dumper(tmp_path) |
|
|
| @d.ctx(phase="train") |
| def buggy_fn() -> None: |
| raise RuntimeError("boom") |
|
|
| with pytest.raises(RuntimeError, match="boom"): |
| buggy_fn() |
|
|
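| # The decorator must restore global_ctx even when the wrapped fn raises |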
| assert d._state.global_ctx == {} |
|
|
| def test_ctx_rejects_mixed_args(self) -> None: |
| d = _make_test_dumper("/tmp") |
|
|
| with pytest.raises(ValueError, match="cannot mix"): |
| d.ctx(lambda self: dict(a=1), phase="x") |
|
|
| def test_ctx_rejects_empty_args(self) -> None: |
| d = _make_test_dumper("/tmp") |
|
|
| with pytest.raises(ValueError, match="must provide"): |
| d.ctx() |
|
|
|
|
| class TestRecomputeStatus: |
| def test_disabled_by_default(self, tmp_path: Path) -> None: |
| d = _make_test_dumper(tmp_path) |
| tensor = torch.randn(3, 3) |
| d.dump("test_tensor", tensor) |
|
|
| filenames = _get_filenames(tmp_path) |
| _assert_files(filenames, exist=["recompute_status=disabled"]) |
|
|
| def test_recompute_status_in_embedded_meta(self, tmp_path: Path) -> None: |
| d = _make_test_dumper(tmp_path) |
| tensor = torch.randn(3, 3) |
| d.dump("test_tensor", tensor) |
|
|
| path = _find_dump_file(tmp_path, rank=0, name="test_tensor") |
| raw = _load_dump(path) |
| assert raw["meta"]["recompute_status"] == "disabled" |
|
|
| def test_recompute_status_recompute(self, tmp_path: Path, monkeypatch) -> None: |
| import sglang.srt.debug_utils.dumper as dumper_mod |
|
|
| monkeypatch.setattr( |
| dumper_mod, "_detect_recompute_status", lambda: _RecomputeStatus.RECOMPUTE |
| ) |
|
|
| d = _make_test_dumper(tmp_path) |
| tensor = torch.randn(3, 3) |
| d.dump("test_tensor", tensor) |
|
|
| filenames = _get_filenames(tmp_path) |
| _assert_files(filenames, exist=["recompute_status=recompute"]) |
|
|
| path = _find_dump_file(tmp_path, rank=0, name="test_tensor") |
| raw = _load_dump(path) |
| assert raw["meta"]["recompute_status"] == "recompute" |
| assert raw["meta"]["recompute_pseudo_rank"] == 1 |
| assert raw["meta"]["recompute_pseudo_size"] == 2 |
|
|
| def test_recompute_status_original(self, tmp_path: Path, monkeypatch) -> None: |
| import sglang.srt.debug_utils.dumper as dumper_mod |
|
|
| monkeypatch.setattr( |
| dumper_mod, |
| "_detect_recompute_status", |
| lambda: _RecomputeStatus.ORIGINAL, |
| ) |
|
|
| d = _make_test_dumper(tmp_path) |
| tensor = torch.randn(3, 3) |
| d.dump("test_tensor", tensor) |
|
|
| filenames = _get_filenames(tmp_path) |
| _assert_files(filenames, exist=["recompute_status=original"]) |
|
|
| path = _find_dump_file(tmp_path, rank=0, name="test_tensor") |
| raw = _load_dump(path) |
| assert raw["meta"]["recompute_status"] == "original" |
| assert raw["meta"]["recompute_pseudo_rank"] == 0 |
| assert raw["meta"]["recompute_pseudo_size"] == 2 |
|
|
| def test_disabled_no_recompute_pseudo_fields(self, tmp_path: Path) -> None: |
| d = _make_test_dumper(tmp_path) |
| tensor = torch.randn(3, 3) |
| d.dump("test_tensor", tensor) |
|
|
| path = _find_dump_file(tmp_path, rank=0, name="test_tensor") |
| raw = _load_dump(path) |
| assert "recompute_pseudo_rank" not in raw["meta"] |
| assert "recompute_pseudo_size" not in raw["meta"] |
|
|
| def test_grad_hook_has_no_recompute_status(self, tmp_path: Path) -> None: |
| d = _make_test_dumper(tmp_path, enable_grad=True) |
| x = torch.randn(3, 3, requires_grad=True) |
| y = (x * 2).sum() |
|
|
| d.dump("test_tensor", x) |
| y.backward() |
|
|
| grad_files = [f for f in _get_filenames(tmp_path) if "grad__test_tensor" in f] |
| assert len(grad_files) == 1 |
| assert "recompute_status" not in grad_files[0] |
|
|
| def test_non_intrusive_hooks_have_recompute_status(self, tmp_path: Path) -> None: |
| class Simple(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(4, 4) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| return self.linear(x) |
|
|
| model = Simple() |
| d = _make_test_dumper(tmp_path, non_intrusive_mode="all") |
| d.register_non_intrusive_dumper(model) |
|
|
| with d.capture_output() as captured: |
| model(torch.randn(2, 4)) |
|
|
| for key, data in captured.items(): |
| assert ( |
| "recompute_status" in data["meta"] |
| ), f"missing recompute_status in {key}" |
| assert data["meta"]["recompute_status"] == "disabled" |
|
|
| def test_detect_recompute_status_default(self) -> None: |
| assert _detect_recompute_status() == _RecomputeStatus.DISABLED |
|
|
|
|
| if __name__ == "__main__": |
| sys.exit(pytest.main([__file__])) |
|
|