#!/usr/bin/env python
"""Verify the benign MessagePack/model checkpoint deserialization PoC."""

from __future__ import annotations

import argparse
import hashlib
import importlib.metadata as metadata
import json
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import msgpack
import msgpack_numpy
import numpy as np
import ray
from ray.rllib.utils.checkpoints import Checkpointable

# Default file name of the serialized checkpoint state; by default it is
# looked up next to this script (see ``main``).
ARTIFACT_NAME = "state.msgpack"
# File whose presence in the current working directory after a check is taken
# as evidence that decoding the artifact executed code.
MARKER_NAME = "MSG_PACK_NUMPY_MARKER.txt"


class DemoCheckpointable(Checkpointable):
    """Minimal ``Checkpointable`` that records whatever state RLlib restores.

    Only ``set_state`` does real work: it keeps the decoded state so the
    caller can inspect it after ``restore_from_path`` returns.
    """

    def __init__(self) -> None:
        super().__init__()
        # Populated by ``set_state``; stays None if restore never reaches it.
        self.restored_state: Optional[Dict[str, Any]] = None

    def get_state(self, components=None, *, not_components=None, **kwargs):
        # Nothing to save: this class exists only to be restored into.
        return {}

    def set_state(self, state):
        self.restored_state = state

    def get_ctor_args_and_kwargs(self) -> Tuple[Tuple, Dict[str, Any]]:
        # No constructor arguments are needed to rebuild this object.
        return (), {}


def sha256(path: Path) -> str:
    """Return the hex SHA-256 digest of the file at *path*, read in 1 MiB chunks."""
    digest = hashlib.sha256()
    with path.open("rb") as handle:
        while chunk := handle.read(1024 * 1024):
            digest.update(chunk)
    return digest.hexdigest()


def package_version(name: str) -> str:
    """Return the installed version of distribution *name*, or "not installed"."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return "not installed"


def _key_names(data: Any) -> List[str]:
    """Return the sorted, stringified keys of *data*, or ``[]`` if it is not a dict.

    Keys are stringified before sorting because ``strict_map_key=False`` lets
    msgpack produce non-``str`` keys, and ``sorted`` cannot compare mixed key
    types. Using one helper keeps the key lists of all checks comparable.
    """
    if not isinstance(data, dict):
        return []
    return sorted(str(key) for key in data)


def _read_marker(marker: Path) -> Tuple[bool, Optional[str]]:
    """Return ``(exists, text)`` for *marker*, touching the filesystem once."""
    try:
        return True, marker.read_text(encoding="utf-8")
    except FileNotFoundError:
        return False, None


def plain_msgpack_check(artifact: Path, marker: Path) -> Dict[str, Any]:
    """Decode *artifact* with plain ``msgpack`` (no numpy extension hooks).

    This is the control case: no object-array decoding hook is installed, so
    the marker is not expected to appear.

    Returns:
        The decoded top-level type name, its stringified sorted keys (empty
        if it is not a dict), and whether the marker file exists afterwards.
    """
    marker.unlink(missing_ok=True)
    with artifact.open("rb") as handle:
        data = msgpack.load(handle, raw=False, strict_map_key=False)
    return {
        "plain_msgpack_type": type(data).__name__,
        "plain_msgpack_keys": _key_names(data),
        "marker_created": marker.exists(),
    }


def rllib_restore_check(checkpoint_dir: Path, marker: Path) -> Dict[str, Any]:
    """Restore *checkpoint_dir* through RLlib's ``restore_from_path``.

    WARNING: this script's own findings are that this path can execute code
    embedded in the artifact (msgpack-numpy object-array pickle decoding).
    Run it only in a disposable, isolated environment.

    Returns:
        The restored state's keys, the type and repr of its ``"object_array"``
        entry (``None`` if absent), and the marker's existence and contents.
    """
    marker.unlink(missing_ok=True)
    demo = DemoCheckpointable()
    demo.restore_from_path(checkpoint_dir)
    # Guard against a missing or non-dict restored state so the report is
    # still produced instead of crashing on ``.get``.
    state = demo.restored_state
    restored: Dict[Any, Any] = state if isinstance(state, dict) else {}
    marker_created, marker_text = _read_marker(marker)
    object_value = restored.get("object_array")
    return {
        "restored_keys": _key_names(restored),
        "object_array_type": type(object_value).__name__,
        "object_array_repr": repr(object_value),
        "marker_created": marker_created,
        "marker_text": marker_text,
    }


def direct_msgpack_numpy_check(artifact: Path, marker: Path) -> Dict[str, Any]:
    """Decode *artifact* with ``msgpack_numpy.load`` directly, bypassing RLlib.

    WARNING: msgpack-numpy decodes object-dtype arrays via pickle, so this
    call can execute code embedded in the artifact. Run it only in a
    disposable, isolated environment.

    Returns:
        The decoded top-level type name, its stringified sorted keys (empty
        if it is not a dict), and the marker's existence and contents.
    """
    marker.unlink(missing_ok=True)
    with artifact.open("rb") as handle:
        data = msgpack_numpy.load(handle, raw=False, strict_map_key=False)
    marker_created, marker_text = _read_marker(marker)
    return {
        "msgpack_numpy_type": type(data).__name__,
        "msgpack_numpy_keys": _key_names(data),
        "marker_created": marker_created,
        "marker_text": marker_text,
    }


def main() -> None:
    """Run all three checks, write a JSON report, and fail if RLlib made no marker.

    Raises:
        FileNotFoundError: if the artifact does not exist.
        SystemExit: if the RLlib restore path did not create the marker.
    """
    script_dir = Path(__file__).resolve().parent
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--artifact", type=Path, default=script_dir / ARTIFACT_NAME)
    parser.add_argument("--results", type=Path, default=script_dir / "results.json")
    args = parser.parse_args()

    artifact = args.artifact.resolve()
    if not artifact.exists():
        raise FileNotFoundError(artifact)
    # RLlib restores from the directory that contains the state file.
    checkpoint_dir = artifact.parent
    # NOTE(review): the marker is looked up in the current working directory,
    # presumably because the payload writes a relative path -- confirm.
    marker = Path.cwd() / MARKER_NAME

    # The checks run in the order the dict literal lists them; each one
    # deletes the marker first, so their results are independent.
    results = {
        "artifact": str(artifact),
        "artifact_sha256": sha256(artifact),
        "artifact_size_bytes": artifact.stat().st_size,
        "versions": {
            "python": sys.version,
            "ray": ray.__version__,
            "msgpack": package_version("msgpack"),
            "msgpack-numpy": package_version("msgpack-numpy"),
            "numpy": np.__version__,
            "modelscan": package_version("modelscan"),
        },
        "plain_msgpack_check": plain_msgpack_check(artifact, marker),
        "direct_msgpack_numpy_check": direct_msgpack_numpy_check(artifact, marker),
        "ray_rllib_restore_check": rllib_restore_check(checkpoint_dir, marker),
        "limitation": "This is ACE via msgpack-numpy object-array pickle decoding during RLlib msgpack checkpoint restore; it is not a native parser memory-corruption issue.",
    }

    # default=str stringifies values json cannot encode (e.g. restored arrays).
    report = json.dumps(results, indent=2, default=str)
    args.results.write_text(report + "\n", encoding="utf-8")
    print(report)

    if not results["ray_rllib_restore_check"]["marker_created"]:
        raise SystemExit("marker was not created through Ray RLlib restore path")


if __name__ == "__main__":
    main()