| |
| """Verify the benign MessagePack/model checkpoint deserialization PoC.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import hashlib |
| import importlib.metadata as metadata |
| import json |
| import sys |
| from pathlib import Path |
| from typing import Any, Dict, Optional, Tuple |
|
|
| import msgpack |
| import msgpack_numpy |
| import numpy as np |
| import ray |
| from ray.rllib.utils.checkpoints import Checkpointable |
|
|
|
|
# File name of the PoC checkpoint state; presumably the name Ray RLlib's
# restore path looks for inside the checkpoint directory (verify).
ARTIFACT_NAME: str = "state.msgpack"
# Marker file whose appearance signals that deserialization ran the payload;
# main() looks for it in the current working directory.
MARKER_NAME: str = "MSG_PACK_NUMPY_MARKER.txt"
|
|
|
|
class DemoCheckpointable(Checkpointable):
    """Minimal ``Checkpointable`` that only captures whatever gets restored.

    The inherited ``restore_from_path`` is expected to pass the decoded
    checkpoint state to ``set_state``; this class keeps it on
    ``restored_state`` so callers can inspect what the checkpoint decoded to.
    """

    def __init__(self) -> None:
        # Stays None until set_state() is called.
        self.restored_state: Optional[Dict[str, Any]] = None

    def get_state(self, components=None, *, not_components=None, **kwargs):
        """Return an empty state; the component filters are ignored."""
        empty_state: Dict[str, Any] = {}
        return empty_state

    def set_state(self, state):
        """Store *state* exactly as it was handed in."""
        self.restored_state = state

    def get_ctor_args_and_kwargs(self) -> Tuple[Tuple, Dict[str, Any]]:
        """Report that no constructor arguments are needed."""
        ctor_args: Tuple = ()
        ctor_kwargs: Dict[str, Any] = {}
        return ctor_args, ctor_kwargs
|
|
|
|
def sha256(path: Path) -> str:
    """Return the hexadecimal SHA-256 digest of the file at *path*.

    The file is read in 1 MiB chunks so large artifacts are hashed without
    being loaded into memory all at once.
    """
    chunk_size = 1024 * 1024
    hasher = hashlib.sha256()
    with path.open("rb") as stream:
        # read() returns b"" at EOF, which ends the loop.
        while block := stream.read(chunk_size):
            hasher.update(block)
    return hasher.hexdigest()
|
|
|
|
def package_version(name: str) -> str:
    """Return the installed version of distribution *name*.

    Falls back to the string ``"not installed"`` when no distribution with
    that name is found, so the result is always a plain string.
    """
    missing = metadata.PackageNotFoundError
    try:
        version = metadata.version(name)
    except missing:
        version = "not installed"
    return version
|
|
|
|
def plain_msgpack_check(artifact: Path, marker: Path) -> Dict[str, Any]:
    """Decode *artifact* with plain ``msgpack`` as a negative control.

    No numpy extension hooks are passed to ``msgpack.load`` here. This path
    is therefore intended to show whether decoding the raw msgpack structure
    alone creates *marker*.

    Args:
        artifact: The msgpack-encoded state file to decode.
        marker: Marker file to look for afterwards. It is removed first so a
            leftover marker from an earlier run cannot be misreported.

    Returns:
        A dict with ``plain_msgpack_type`` (type name of the decoded
        top-level object), ``plain_msgpack_keys`` (its keys stringified and
        sorted, or ``None`` when the top-level object is not a map) and
        ``marker_created`` (whether *marker* exists after decoding).
    """
    marker.unlink(missing_ok=True)
    with artifact.open("rb") as handle:
        data = msgpack.load(handle, raw=False, strict_map_key=False)
    # A non-map top-level object is still valid msgpack; report it instead of
    # crashing on ``.keys()``. Keys are stringified because strict_map_key=False
    # allows non-str (and mixed-type) keys that would not sort together.
    keys = sorted(str(k) for k in data) if isinstance(data, dict) else None
    return {
        "plain_msgpack_type": type(data).__name__,
        "plain_msgpack_keys": keys,
        "marker_created": marker.exists(),
    }
|
|
|
|
def rllib_restore_check(checkpoint_dir: Path, marker: Path) -> Dict[str, Any]:
    """Restore *checkpoint_dir* through Ray RLlib's ``Checkpointable`` API.

    This is the path under test. ``restore_from_path`` is expected to decode
    the directory's state file and hand the result to
    ``DemoCheckpointable.set_state``. The report records what was restored
    and whether *marker* appeared.

    SECURITY: this intentionally deserializes the checkpoint. Only run it on
    the PoC artifact or inside a sandbox; RLlib checkpoints should never be
    restored from untrusted sources.

    Args:
        checkpoint_dir: Directory passed to ``restore_from_path``.
        marker: Marker file to look for afterwards. It is removed first so a
            leftover marker from an earlier run cannot be misreported.

    Returns:
        A dict with the restored state's type name, its keys (stringified
        and sorted), the type and ``repr`` of its ``"object_array"`` entry,
        whether *marker* exists, and the marker's text (``None`` if absent).
    """
    marker.unlink(missing_ok=True)
    demo = DemoCheckpointable()
    demo.restore_from_path(checkpoint_dir)
    raw_state = demo.restored_state
    # None (set_state never called) or any non-mapping state is reported as an
    # empty mapping; restored_state_type still shows what actually came back.
    restored = raw_state if isinstance(raw_state, dict) else {}
    # Check the marker once so marker_created and marker_text always agree.
    marker_created = marker.exists()
    marker_text = (
        marker.read_text(encoding="utf-8", errors="replace")
        if marker_created
        else None
    )
    object_value = restored.get("object_array")
    return {
        "restored_state_type": type(raw_state).__name__,
        # Stringified so mixed-type keys (allowed by strict_map_key=False)
        # do not make sorted() raise TypeError.
        "restored_keys": sorted(str(k) for k in restored),
        "object_array_type": type(object_value).__name__,
        "object_array_repr": repr(object_value),
        "marker_created": marker_created,
        "marker_text": marker_text,
    }
|
|
|
|
def direct_msgpack_numpy_check(artifact: Path, marker: Path) -> Dict[str, Any]:
    """Decode *artifact* directly with ``msgpack_numpy.load``, bypassing Ray.

    *marker* is removed beforehand and checked afterwards. If it appears
    here, decoding the artifact with msgpack-numpy alone was enough to
    create it.

    SECURITY: this deliberately deserializes the PoC artifact with
    msgpack-numpy, whose object-array decoding goes through pickle. Never
    point it at untrusted files outside a sandbox.

    Args:
        artifact: The msgpack-encoded state file to decode.
        marker: Marker file to look for afterwards.

    Returns:
        A dict with the decoded top-level type name, its keys stringified
        and sorted (``None`` when the top-level object is not a map), whether
        *marker* exists, and the marker's text (``None`` if absent).
    """
    marker.unlink(missing_ok=True)
    with artifact.open("rb") as handle:
        data = msgpack_numpy.load(handle, raw=False, strict_map_key=False)
    # Check the marker once so marker_created and marker_text always agree.
    marker_created = marker.exists()
    marker_text = (
        marker.read_text(encoding="utf-8", errors="replace")
        if marker_created
        else None
    )
    # Same key handling as plain_msgpack_check: tolerate non-map data and
    # mixed-type keys instead of raising.
    keys = sorted(str(k) for k in data) if isinstance(data, dict) else None
    return {
        "msgpack_numpy_type": type(data).__name__,
        "msgpack_numpy_keys": keys,
        "marker_created": marker_created,
        "marker_text": marker_text,
    }
|
|
|
|
def main() -> None:
    """Run every deserialization check against the artifact and report.

    Parses ``--artifact`` and ``--results``. It then hashes the artifact and
    runs the plain-msgpack, direct msgpack-numpy and Ray RLlib restore
    checks, in that order. The combined JSON report is written to the
    results file and echoed to stdout.

    Raises:
        FileNotFoundError: If the artifact is missing or is not a regular file.
        SystemExit: If the marker was not created through the RLlib restore
            path, so the script exits non-zero when the PoC does not reproduce.
    """
    script_dir = Path(__file__).resolve().parent
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--artifact",
        type=Path,
        default=script_dir / ARTIFACT_NAME,
        help="msgpack checkpoint state file to verify (default: %(default)s)",
    )
    parser.add_argument(
        "--results",
        type=Path,
        default=script_dir / "results.json",
        help="where to write the JSON report (default: %(default)s)",
    )
    args = parser.parse_args()

    artifact = args.artifact.resolve()
    # The RLlib check restores the whole directory, not this one file.
    checkpoint_dir = artifact.parent
    # The marker is looked up relative to the current working directory.
    marker = Path.cwd() / MARKER_NAME

    # is_file() also rejects directories, which exists() would accept and
    # which would only fail later inside a check with a less helpful error.
    if not artifact.is_file():
        raise FileNotFoundError(artifact)
    if artifact.name != ARTIFACT_NAME:
        # NOTE(review): RLlib picks the state file inside checkpoint_dir by its
        # own naming, so a differently named artifact may not be what it loads.
        print(
            f"warning: artifact is not named {ARTIFACT_NAME!r}; the RLlib "
            f"restore check loads {checkpoint_dir} and may not read this file",
            file=sys.stderr,
        )

    # Dict values are evaluated in order, so the checks run plain ->
    # direct -> RLlib. NOTE(review): RLlib's msgpack import path may call
    # msgpack_numpy.patch(), which would change plain msgpack for the rest of
    # the process -- confirm before reordering the checks.
    results: Dict[str, Any] = {
        "artifact": str(artifact),
        "artifact_sha256": sha256(artifact),
        "artifact_size_bytes": artifact.stat().st_size,
        "checkpoint_dir": str(checkpoint_dir),
        "versions": {
            "python": sys.version,
            "ray": ray.__version__,
            "msgpack": package_version("msgpack"),
            "msgpack-numpy": package_version("msgpack-numpy"),
            "numpy": np.__version__,
            "modelscan": package_version("modelscan"),
        },
        "plain_msgpack_check": plain_msgpack_check(artifact, marker),
        "direct_msgpack_numpy_check": direct_msgpack_numpy_check(artifact, marker),
        "ray_rllib_restore_check": rllib_restore_check(checkpoint_dir, marker),
        "limitation": "This is ACE via msgpack-numpy object-array pickle decoding during RLlib msgpack checkpoint restore; it is not a native parser memory-corruption issue.",
    }

    # Serialize once and reuse the text for both the file and stdout.
    report = json.dumps(results, indent=2, default=str)
    # --results may point into a directory that does not exist yet.
    args.results.parent.mkdir(parents=True, exist_ok=True)
    args.results.write_text(report + "\n", encoding="utf-8")
    print(report)

    if not results["ray_rllib_restore_check"]["marker_created"]:
        raise SystemExit("marker was not created through Ray RLlib restore path")
|
|
|
|
# Run the verifier only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|