File size: 4,582 Bytes
2a3c1ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#!/usr/bin/env python
"""Verify the benign MessagePack/model checkpoint deserialization PoC."""

from __future__ import annotations

import argparse
import hashlib
import importlib.metadata as metadata
import json
import sys
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import msgpack
import msgpack_numpy
import numpy as np
import ray
from ray.rllib.utils.checkpoints import Checkpointable


# File name of the MessagePack artifact; main() uses it to build the default
# ``--artifact`` path next to this script.
ARTIFACT_NAME = "state.msgpack"
# File name of the marker file that main() looks for in the current working
# directory after each check.
MARKER_NAME = "MSG_PACK_NUMPY_MARKER.txt"


class DemoCheckpointable(Checkpointable):
    """Bare-bones ``Checkpointable`` that only remembers the state handed to it.

    It saves nothing of its own (``get_state`` is empty) and takes no
    constructor arguments; its sole purpose is to expose, via
    ``restored_state``, whatever ``set_state`` receives.
    """

    def __init__(self) -> None:
        # Stays None until set_state() is called.
        self.restored_state: Optional[Dict[str, Any]] = None

    def get_state(self, components=None, *, not_components=None, **kwargs):
        """Return an empty state mapping; every argument is ignored."""
        empty_state: Dict[str, Any] = {}
        return empty_state

    def set_state(self, state):
        """Store ``state`` unchanged on ``self.restored_state``."""
        self.restored_state = state

    def get_ctor_args_and_kwargs(self) -> Tuple[Tuple, Dict[str, Any]]:
        """Return empty positional and keyword constructor arguments."""
        ctor_args: Tuple = ()
        ctor_kwargs: Dict[str, Any] = {}
        return ctor_args, ctor_kwargs


def sha256(path: Path) -> str:
    """Return the hexadecimal SHA-256 digest of the file at ``path``.

    The file is opened in binary mode and consumed in 1 MiB reads, so the
    whole content is never held in memory at once.
    """
    read_size = 1024 * 1024
    hasher = hashlib.sha256()
    with path.open("rb") as stream:
        while True:
            block = stream.read(read_size)
            if not block:
                # An empty read marks end of file.
                break
            hasher.update(block)
    return hasher.hexdigest()


def package_version(name: str) -> str:
    """Return the installed version of distribution ``name``.

    Falls back to the literal string ``"not installed"`` when the
    distribution's metadata cannot be found.
    """
    try:
        found = metadata.version(name)
    except metadata.PackageNotFoundError:
        found = "not installed"
    return found


def plain_msgpack_check(artifact: Path, marker: Path) -> Dict[str, Any]:
    """Decode ``artifact`` with plain ``msgpack`` and summarise the result.

    Any existing ``marker`` file is deleted before decoding, so the reported
    ``marker_created`` flag describes the state after this decode.

    Returns:
        A dict with ``plain_msgpack_type`` (type name of the decoded object),
        ``plain_msgpack_keys`` (its keys, stringified and sorted) and
        ``marker_created`` (whether ``marker`` exists afterwards).
    """
    marker.unlink(missing_ok=True)

    with artifact.open("rb") as stream:
        decoded = msgpack.load(stream, raw=False, strict_map_key=False)

    # Keys are stringified before sorting so they sort as plain text.
    key_names = sorted(map(str, decoded.keys()))
    summary: Dict[str, Any] = {
        "plain_msgpack_type": type(decoded).__name__,
        "plain_msgpack_keys": key_names,
        "marker_created": marker.exists(),
    }
    return summary


def rllib_restore_check(checkpoint_dir: Path, marker: Path) -> Dict[str, Any]:
    """Restore a ``DemoCheckpointable`` from ``checkpoint_dir`` and report on it.

    Any existing ``marker`` file is deleted first. A fresh
    ``DemoCheckpointable`` is then asked to ``restore_from_path`` the
    checkpoint directory, and whatever ends up in its ``restored_state``
    (presumably delivered through ``set_state`` by the restore machinery)
    is summarised.

    Args:
        checkpoint_dir: Directory passed to ``restore_from_path``.
        marker: File whose existence and text are reported after the restore.

    Returns:
        A dict with ``restored_keys`` (state keys, stringified and sorted),
        ``object_array_type`` / ``object_array_repr`` (type name and repr of
        the ``"object_array"`` entry, ``None`` when absent),
        ``marker_created`` and ``marker_text`` (``None`` when no marker).
    """
    marker.unlink(missing_ok=True)
    demo = DemoCheckpointable()
    demo.restore_from_path(checkpoint_dir)
    # Treat "nothing restored" (None or empty) as an empty mapping.
    restored = demo.restored_state or {}

    # Check the marker once so the flag and the text cannot disagree.
    marker_created = marker.exists()
    marker_text = marker.read_text(encoding="utf-8") if marker_created else None

    object_value = restored.get("object_array")
    return {
        # Stringify before sorting, as plain_msgpack_check does, so
        # non-string or mixed-type keys cannot make sorted() raise TypeError.
        "restored_keys": sorted(str(key) for key in restored),
        "object_array_type": type(object_value).__name__,
        "object_array_repr": repr(object_value),
        "marker_created": marker_created,
        "marker_text": marker_text,
    }


def direct_msgpack_numpy_check(artifact: Path, marker: Path) -> Dict[str, Any]:
    """Decode ``artifact`` directly with ``msgpack_numpy`` and report on it.

    Any existing ``marker`` file is deleted first, so the marker fields in
    the result describe the state after this decode.

    Args:
        artifact: MessagePack file to decode.
        marker: File whose existence and text are reported after decoding.

    Returns:
        A dict with ``msgpack_numpy_type`` (type name of the decoded object),
        ``msgpack_numpy_keys`` (its keys, stringified and sorted),
        ``marker_created`` and ``marker_text`` (``None`` when no marker).
    """
    marker.unlink(missing_ok=True)
    # SECURITY NOTE: per the "limitation" text recorded by main(), msgpack-numpy
    # decodes object arrays through pickle, which can execute code. That is the
    # behaviour this check demonstrates; never point it at untrusted files
    # outside of a sandbox.
    with artifact.open("rb") as handle:
        data = msgpack_numpy.load(handle, raw=False, strict_map_key=False)

    # Check the marker once so the flag and the text cannot disagree.
    marker_created = marker.exists()
    marker_text = marker.read_text(encoding="utf-8") if marker_created else None

    return {
        "msgpack_numpy_type": type(data).__name__,
        # strict_map_key=False allows non-string keys; stringify before
        # sorting (as plain_msgpack_check does) so mixed key types cannot
        # make sorted() raise TypeError.
        "msgpack_numpy_keys": sorted(str(key) for key in data.keys()),
        "marker_created": marker_created,
        "marker_text": marker_text,
    }


def main() -> None:
    """Run all three checks against the artifact and record the results.

    Command-line options:
        --artifact: MessagePack file to inspect (default: ``ARTIFACT_NAME``
            next to this script).
        --results: JSON file the report is written to (default:
            ``results.json`` next to this script).

    The report is written to the results file and echoed to stdout.

    Raises:
        FileNotFoundError: The artifact path does not exist.
        SystemExit: The RLlib restore check did not report the marker file.
    """
    script_dir = Path(__file__).resolve().parent

    cli = argparse.ArgumentParser()
    cli.add_argument("--artifact", type=Path, default=script_dir / ARTIFACT_NAME)
    cli.add_argument("--results", type=Path, default=script_dir / "results.json")
    options = cli.parse_args()

    artifact = options.artifact.resolve()
    # The RLlib check restores from the directory that holds the artifact.
    checkpoint_dir = artifact.parent
    # The marker is looked for in the current working directory.
    marker = Path.cwd() / MARKER_NAME

    if not artifact.exists():
        raise FileNotFoundError(artifact)

    # Built entry by entry, in the order the report lists them.
    report: Dict[str, Any] = {}
    report["artifact"] = str(artifact)
    report["artifact_sha256"] = sha256(artifact)
    report["artifact_size_bytes"] = artifact.stat().st_size
    report["versions"] = {
        "python": sys.version,
        "ray": ray.__version__,
        "msgpack": package_version("msgpack"),
        "msgpack-numpy": package_version("msgpack-numpy"),
        "numpy": np.__version__,
        "modelscan": package_version("modelscan"),
    }
    report["plain_msgpack_check"] = plain_msgpack_check(artifact, marker)
    report["direct_msgpack_numpy_check"] = direct_msgpack_numpy_check(artifact, marker)
    report["ray_rllib_restore_check"] = rllib_restore_check(checkpoint_dir, marker)
    report["limitation"] = (
        "This is ACE via msgpack-numpy object-array pickle decoding during "
        "RLlib msgpack checkpoint restore; it is not a native parser "
        "memory-corruption issue."
    )

    # Render once; the same text goes to the results file and to stdout.
    rendered = json.dumps(report, indent=2, default=str)
    options.results.write_text(rendered + "\n", encoding="utf-8")
    print(rendered)

    restore_outcome = report["ray_rllib_restore_check"]
    if restore_outcome["marker_created"]:
        return
    raise SystemExit("marker was not created through Ray RLlib restore path")


# Run the verification only when this file is executed as a script.
if __name__ == "__main__":
    main()