pragnyanramtha's picture
Upload folder using huggingface_hub
2a3c1ea verified
#!/usr/bin/env python
"""Verify the benign MessagePack/model checkpoint deserialization PoC."""
from __future__ import annotations
import argparse
import hashlib
import importlib.metadata as metadata
import json
import sys
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
import msgpack
import msgpack_numpy
import numpy as np
import ray
from ray.rllib.utils.checkpoints import Checkpointable
# Default artifact file name; main() looks for it next to this script when
# --artifact is not supplied.
ARTIFACT_NAME: str = "state.msgpack"
# Marker file name, resolved against the current working directory by main().
# The checks treat its appearance after a load as evidence that code embedded
# in the artifact executed.
MARKER_NAME: str = "MSG_PACK_NUMPY_MARKER.txt"
class DemoCheckpointable(Checkpointable):
    """Minimal ``Checkpointable`` used only to drive RLlib's restore path.

    It has no state of its own to save. Whatever state is handed to
    ``set_state`` is kept on ``restored_state`` so the caller can inspect it
    afterwards. (``restore_from_path`` is presumably what calls ``set_state``
    with the loaded checkpoint state; confirm against the Ray version in use.)
    """

    def __init__(self) -> None:
        # Remains None until set_state() has been called.
        self.restored_state: Optional[Dict[str, Any]] = None

    def get_state(self, components=None, *, not_components=None, **kwargs):
        """Return an empty state: this demo object contributes nothing."""
        empty_state: Dict[str, Any] = {}
        return empty_state

    def set_state(self, state):
        """Store ``state`` verbatim for later inspection."""
        self.restored_state = state

    def get_ctor_args_and_kwargs(self) -> Tuple[Tuple, Dict[str, Any]]:
        """Report that the constructor takes no positional or keyword args."""
        positional: Tuple = ()
        keyword: Dict[str, Any] = {}
        return positional, keyword
def sha256(path: Path) -> str:
    """Return the hex-encoded SHA-256 digest of the file at ``path``.

    The file is streamed in 1 MiB pieces, so a large artifact is never read
    into memory all at once.
    """
    block_size = 1024 * 1024
    hasher = hashlib.sha256()
    with path.open("rb") as stream:
        while block := stream.read(block_size):
            hasher.update(block)
    return hasher.hexdigest()
def package_version(name: str) -> str:
    """Return the installed version string of distribution ``name``.

    If no distribution with that name is installed, the literal text
    ``"not installed"`` is returned instead of raising.
    """
    fallback = "not installed"
    try:
        found = metadata.version(name)
    except metadata.PackageNotFoundError:
        found = fallback
    return found
def plain_msgpack_check(artifact: Path, marker: Path) -> Dict[str, Any]:
    """Decode ``artifact`` with stock msgpack (no numpy extension hooks).

    This is the control case. ``msgpack.load`` is called without any object or
    ext hooks, so decoding is expected not to run code embedded in the
    artifact, and the marker is expected to stay absent.

    Args:
        artifact: Path to the MessagePack file to decode.
        marker: Marker file path. It is deleted first so a stale marker from
            an earlier run cannot produce a false positive.

    Returns:
        A dict with the decoded top-level type name, the top-level map keys
        as sorted strings (``None`` when the top level is not a map), and
        whether the marker exists after decoding.
    """
    marker.unlink(missing_ok=True)
    with artifact.open("rb") as handle:
        # strict_map_key=False relaxes msgpack's default restriction of map
        # keys to str/bytes.
        data = msgpack.load(handle, raw=False, strict_map_key=False)
    # Only a top-level map has keys; a list or scalar payload would otherwise
    # crash here with AttributeError on ``.keys()``.
    keys = sorted(str(k) for k in data) if isinstance(data, dict) else None
    return {
        "plain_msgpack_type": type(data).__name__,
        "plain_msgpack_keys": keys,
        "marker_created": marker.exists(),
    }
def rllib_restore_check(checkpoint_dir: Path, marker: Path) -> Dict[str, Any]:
    """Restore ``checkpoint_dir`` via ``DemoCheckpointable.restore_from_path``.

    This is the end-to-end path under test. RLlib's own restore logic loads
    the checkpoint, and whatever state reaches ``set_state`` is inspected.

    WARNING: per this PoC, restoring a msgpack checkpoint can unpickle numpy
    object arrays and thereby execute code embedded in the checkpoint. Only
    run this against artifacts you are prepared to execute, e.g. in a sandbox.

    Args:
        checkpoint_dir: Directory handed to ``restore_from_path``.
        marker: Marker file path. It is deleted first so only a marker
            produced by this restore is observed.

    Returns:
        A dict with:
        - the restored top-level keys, as sorted strings;
        - the type name and ``repr`` of the ``"object_array"`` entry
          (``"NoneType"`` / ``"None"`` when absent);
        - whether the marker exists;
        - the marker's text, or ``None`` when it does not exist.
    """
    marker.unlink(missing_ok=True)
    demo = DemoCheckpointable()
    demo.restore_from_path(checkpoint_dir)
    # restored_state stays None if set_state was never called.
    restored = demo.restored_state or {}
    marker_text = marker.read_text(encoding="utf-8") if marker.exists() else None
    object_value = restored.get("object_array")
    return {
        # Stringify before sorting, matching plain_msgpack_check: sorting
        # mixed key types (e.g. str and bytes) raises TypeError.
        "restored_keys": sorted(str(k) for k in restored),
        "object_array_type": type(object_value).__name__,
        "object_array_repr": repr(object_value),
        "marker_created": marker.exists(),
        "marker_text": marker_text,
    }
def direct_msgpack_numpy_check(artifact: Path, marker: Path) -> Dict[str, Any]:
    """Decode ``artifact`` directly with ``msgpack_numpy.load``.

    This isolates the msgpack-numpy decoder from the RLlib restore machinery.

    WARNING: per this PoC, msgpack-numpy decodes object arrays via pickle,
    which can execute code embedded in the artifact. Only run this against
    artifacts you are prepared to execute, e.g. in a sandbox.

    Args:
        artifact: Path to the MessagePack file to decode.
        marker: Marker file path. It is deleted first so only a marker
            produced by this decode is observed.

    Returns:
        A dict with:
        - the decoded top-level type name;
        - the top-level map keys as sorted strings (``None`` when the top
          level is not a map);
        - whether the marker exists;
        - the marker's text, or ``None`` when it does not exist.
    """
    marker.unlink(missing_ok=True)
    with artifact.open("rb") as handle:
        data = msgpack_numpy.load(handle, raw=False, strict_map_key=False)
    marker_text = marker.read_text(encoding="utf-8") if marker.exists() else None
    # Stringify before sorting, matching plain_msgpack_check: sorting mixed
    # key types raises TypeError, and a non-map top level has no keys at all.
    keys = sorted(str(k) for k in data) if isinstance(data, dict) else None
    return {
        "msgpack_numpy_type": type(data).__name__,
        "msgpack_numpy_keys": keys,
        "marker_created": marker.exists(),
        "marker_text": marker_text,
    }
def main() -> None:
    """Run all three load checks against one artifact and report the results.

    Command-line options:
    - ``--artifact``: the file to test; defaults to ``state.msgpack`` next to
      this script. Its parent directory is used as the RLlib checkpoint
      directory.
    - ``--results``: where to write the JSON report; defaults to
      ``results.json`` next to this script.

    The marker file is resolved against the current working directory. The
    JSON report is written to ``--results`` and echoed to stdout.

    Raises:
        FileNotFoundError: if the artifact is missing or is not a regular file.
        SystemExit: if the RLlib restore path did not create the marker.
    """
    script_dir = Path(__file__).resolve().parent
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--artifact",
        type=Path,
        default=script_dir / ARTIFACT_NAME,
        help="MessagePack checkpoint state file to test.",
    )
    parser.add_argument(
        "--results",
        type=Path,
        default=script_dir / "results.json",
        help="Path of the JSON results file to write.",
    )
    args = parser.parse_args()

    artifact = args.artifact.resolve()
    checkpoint_dir = artifact.parent
    marker = Path.cwd() / MARKER_NAME
    # is_file() rather than exists(): a directory would pass exists() and then
    # fail inside sha256() with a far less helpful IsADirectoryError.
    if not artifact.is_file():
        raise FileNotFoundError(artifact)

    # The checks run in dict-literal order: the plain msgpack control first,
    # then the two decoders that can unpickle. Each check clears the marker
    # before loading.
    results = {
        "artifact": str(artifact),
        "artifact_sha256": sha256(artifact),
        "artifact_size_bytes": artifact.stat().st_size,
        "versions": {
            "python": sys.version,
            "ray": ray.__version__,
            "msgpack": package_version("msgpack"),
            "msgpack-numpy": package_version("msgpack-numpy"),
            "numpy": np.__version__,
            "modelscan": package_version("modelscan"),
        },
        "plain_msgpack_check": plain_msgpack_check(artifact, marker),
        "direct_msgpack_numpy_check": direct_msgpack_numpy_check(artifact, marker),
        "ray_rllib_restore_check": rllib_restore_check(checkpoint_dir, marker),
        "limitation": "This is ACE via msgpack-numpy object-array pickle decoding during RLlib msgpack checkpoint restore; it is not a native parser memory-corruption issue.",
    }

    # Serialize once and reuse for both the file and stdout.
    report = json.dumps(results, indent=2, default=str) + "\n"
    # Create the destination directory so a fresh --results path does not
    # fail only after every check has already run.
    args.results.parent.mkdir(parents=True, exist_ok=True)
    args.results.write_text(report, encoding="utf-8")
    print(report, end="")

    if not results["ray_rllib_restore_check"]["marker_created"]:
        raise SystemExit("marker was not created through Ray RLlib restore path")
if __name__ == "__main__":
main()