File size: 4,582 Bytes
2a3c1ea | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 | #!/usr/bin/env python
"""Verify the benign MessagePack/model checkpoint deserialization PoC."""
from __future__ import annotations
import argparse
import hashlib
import importlib.metadata as metadata
import json
import sys
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
import msgpack
import msgpack_numpy
import numpy as np
import ray
from ray.rllib.utils.checkpoints import Checkpointable
# Default file name of the MessagePack checkpoint state that main() inspects
# (looked up next to this script unless --artifact overrides it).
ARTIFACT_NAME: str = "state.msgpack"
# Name of the marker file whose appearance in the current working directory
# the checks use as evidence that the artifact's embedded code ran.
MARKER_NAME: str = "MSG_PACK_NUMPY_MARKER.txt"
class DemoCheckpointable(Checkpointable):
    """Minimal RLlib ``Checkpointable`` that only records the state it receives.

    Saving contributes an empty state. Whatever RLlib hands to ``set_state``
    during a restore is kept verbatim in ``restored_state`` for inspection.
    """

    def __init__(self) -> None:
        # Filled in by set_state(); stays None until a restore has happened.
        self.restored_state: Optional[Dict[str, Any]] = None

    def get_state(self, components=None, *, not_components=None, **kwargs):
        """Return a fresh empty state; this demo object has nothing to save."""
        empty_state = dict()
        return empty_state

    def set_state(self, state):
        """Store *state* unchanged so callers can examine what was restored."""
        self.restored_state = state

    def get_ctor_args_and_kwargs(self) -> Tuple[Tuple, Dict[str, Any]]:
        """Report no constructor arguments, matching the no-arg ``__init__``."""
        ctor_args = tuple()
        ctor_kwargs = dict()
        return ctor_args, ctor_kwargs
def sha256(path: Path) -> str:
    """Return the lowercase hex SHA-256 digest of the file at *path*.

    The file is read in 1 MiB chunks, so the whole artifact is never held in
    memory at once.
    """
    hasher = hashlib.sha256()
    chunk_size = 1 << 20  # 1 MiB
    with path.open("rb") as stream:
        while True:
            block = stream.read(chunk_size)
            if not block:
                break
            hasher.update(block)
    return hasher.hexdigest()
def package_version(name: str) -> str:
    """Return the installed version of distribution *name*.

    Returns the literal string ``"not installed"`` when no distribution with
    that name can be found, so the result is always JSON-friendly text.
    """
    try:
        found = metadata.version(name)
    except metadata.PackageNotFoundError:
        found = "not installed"
    return found
def plain_msgpack_check(artifact: Path, marker: Path) -> Dict[str, Any]:
    """Decode *artifact* with plain ``msgpack`` (no msgpack-numpy hook) and probe *marker*.

    This check is presumably a control: it uses ``msgpack.load`` directly
    rather than ``msgpack_numpy.load``. It lets the report contrast plain
    decoding with the msgpack-numpy and RLlib restore paths.

    Args:
        artifact: MessagePack file to decode.
        marker: Marker file. It is deleted before decoding and probed afterwards.

    Returns:
        A JSON-friendly summary with these keys:

        - the decoded top-level type name;
        - its map keys stringified and sorted (``None`` when the top-level
          object is not a map);
        - whether the marker exists afterwards;
        - the marker's text, if it exists.
    """
    marker.unlink(missing_ok=True)
    with artifact.open("rb") as handle:
        # strict_map_key=False admits non-str map keys, hence str(k) below so
        # that sorting mixed key types cannot raise TypeError.
        data = msgpack.load(handle, raw=False, strict_map_key=False)
    marker_created = marker.exists()
    marker_text = marker.read_text(encoding="utf-8") if marker_created else None
    keys = sorted(str(k) for k in data.keys()) if isinstance(data, dict) else None
    return {
        "plain_msgpack_type": type(data).__name__,
        "plain_msgpack_keys": keys,
        "marker_created": marker_created,
        "marker_text": marker_text,
    }
def rllib_restore_check(checkpoint_dir: Path, marker: Path) -> Dict[str, Any]:
    """Restore a ``DemoCheckpointable`` from *checkpoint_dir* through RLlib and probe *marker*.

    WARNING: this drives RLlib's own checkpoint loader over an untrusted
    artifact. Per this script's stated limitation, msgpack-numpy object-array
    decoding on this path can execute code embedded in the artifact. Run it
    only in a disposable environment.

    Args:
        checkpoint_dir: Directory handed to ``restore_from_path``.
        marker: Marker file. It is deleted before the restore and probed afterwards.

    Returns:
        A JSON-friendly summary with these keys:

        - the type of the restored state;
        - its keys stringified and sorted (empty when the state is not a dict);
        - the type and ``repr`` of its ``"object_array"`` entry;
        - whether the marker exists afterwards, and its text if so.
    """
    marker.unlink(missing_ok=True)
    demo = DemoCheckpointable()
    # RLlib is expected to pass the deserialized state to demo.set_state(),
    # which stores it in demo.restored_state.
    demo.restore_from_path(checkpoint_dir)
    raw_state = demo.restored_state
    # Guard against a truthy non-dict state, which would break .get() below.
    restored = raw_state if isinstance(raw_state, dict) else {}
    marker_created = marker.exists()
    marker_text = marker.read_text(encoding="utf-8") if marker_created else None
    object_value = restored.get("object_array")
    return {
        "restored_state_type": type(raw_state).__name__,
        # str(k): restored keys may be of mixed types, which sorted() rejects.
        "restored_keys": sorted(str(k) for k in restored.keys()),
        "object_array_type": type(object_value).__name__,
        "object_array_repr": repr(object_value),
        "marker_created": marker_created,
        "marker_text": marker_text,
    }
def direct_msgpack_numpy_check(artifact: Path, marker: Path) -> Dict[str, Any]:
    """Decode *artifact* with ``msgpack_numpy.load`` and probe *marker*.

    WARNING: per this script's stated limitation, msgpack-numpy decodes
    object arrays via pickle. Loading an untrusted artifact here can execute
    code embedded in it. Run it only in a disposable environment.

    Args:
        artifact: MessagePack file to decode.
        marker: Marker file. It is deleted before decoding and probed afterwards.

    Returns:
        A JSON-friendly summary with these keys:

        - the decoded top-level type name;
        - its map keys stringified and sorted (``None`` when the top-level
          object is not a map);
        - whether the marker exists afterwards;
        - the marker's text, if it exists.
    """
    marker.unlink(missing_ok=True)
    with artifact.open("rb") as handle:
        # strict_map_key=False admits non-str map keys, hence str(k) below so
        # that sorting mixed key types cannot raise TypeError.
        data = msgpack_numpy.load(handle, raw=False, strict_map_key=False)
    marker_created = marker.exists()
    marker_text = marker.read_text(encoding="utf-8") if marker_created else None
    keys = sorted(str(k) for k in data.keys()) if isinstance(data, dict) else None
    return {
        "msgpack_numpy_type": type(data).__name__,
        "msgpack_numpy_keys": keys,
        "marker_created": marker_created,
        "marker_text": marker_text,
    }
def main() -> None:
    """Run every deserialization check against one artifact and report the results.

    Command-line options:
        --artifact  MessagePack state file to inspect. Defaults to
                    ``ARTIFACT_NAME`` next to this script. Its parent directory
                    is also used as the RLlib checkpoint directory.
        --results   Path of the JSON report. Defaults to ``results.json`` next to
                    this script. Missing parent directories are created.

    The report is written to ``--results`` and echoed to stdout.

    WARNING: per the report's own ``limitation`` text, the msgpack-numpy and
    RLlib checks can execute code embedded in the artifact. Run this only in
    a disposable environment.

    Raises:
        FileNotFoundError: if ``--artifact`` is not an existing regular file.
        SystemExit: if the RLlib restore path did not create the marker file.
    """
    script_dir = Path(__file__).resolve().parent
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--artifact",
        type=Path,
        default=script_dir / ARTIFACT_NAME,
        help="MessagePack state file to test; its directory is restored as the "
        "RLlib checkpoint.",
    )
    parser.add_argument(
        "--results",
        type=Path,
        default=script_dir / "results.json",
        help="Where to write the JSON report.",
    )
    args = parser.parse_args()

    artifact = args.artifact.resolve()
    # is_file() rather than exists(): a directory would otherwise pass this
    # check and only fail later, inside sha256().
    if not artifact.is_file():
        raise FileNotFoundError(artifact)
    # RLlib restores from a directory, so the artifact's parent is the checkpoint.
    # NOTE(review): if the artifact is not named ARTIFACT_NAME, RLlib may restore
    # a different file from this directory than the one hashed below — confirm.
    checkpoint_dir = artifact.parent
    # Every check probes for the marker in the current working directory.
    marker = Path.cwd() / MARKER_NAME

    # Dict values are evaluated in order, so the checks run
    # plain -> msgpack-numpy -> RLlib. Each check deletes the marker first.
    results = {
        "artifact": str(artifact),
        "artifact_sha256": sha256(artifact),
        "artifact_size_bytes": artifact.stat().st_size,
        "versions": {
            "python": sys.version,
            "ray": ray.__version__,
            "msgpack": package_version("msgpack"),
            "msgpack-numpy": package_version("msgpack-numpy"),
            "numpy": np.__version__,
            "modelscan": package_version("modelscan"),
        },
        "plain_msgpack_check": plain_msgpack_check(artifact, marker),
        "direct_msgpack_numpy_check": direct_msgpack_numpy_check(artifact, marker),
        "ray_rllib_restore_check": rllib_restore_check(checkpoint_dir, marker),
        "limitation": "This is ACE via msgpack-numpy object-array pickle decoding during RLlib msgpack checkpoint restore; it is not a native parser memory-corruption issue.",
    }

    report = json.dumps(results, indent=2, default=str)
    # Create the report directory up front, so a fresh --results path cannot
    # fail the run after all checks (and any payload) have already executed.
    args.results.parent.mkdir(parents=True, exist_ok=True)
    args.results.write_text(report + "\n", encoding="utf-8")
    print(report)
    if not results["ray_rllib_restore_check"]["marker_created"]:
        raise SystemExit("marker was not created through Ray RLlib restore path")


if __name__ == "__main__":
    main()
|