File size: 6,702 Bytes
8c738c4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
#!/usr/bin/env python
"""
PT bundler: Bundle multiple .pt files into a single archive without modifying originals.
Supports two formats:
1) ZIP archive (recommended) – exact bytes of each .pt preserved.
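2) PT container – a single .pt (pickle) file containing a dict {relative_path: bytes}. Listing or extracting it goes through torch.load (pickle-based), so only open containers from trusted sources.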
CLI examples (PowerShell):
# Create ZIP bundle from current repo
python tools/pt_bundle.py zip --source . --out models_bundle.zip
# Create PT container bundle
python tools/pt_bundle.py pt --source . --out models_multi.pt
# List contents
python tools/pt_bundle.py list --bundle models_bundle.zip
python tools/pt_bundle.py list --bundle models_multi.pt
# Extract a single model from bundle to a path
python tools/pt_bundle.py extract --bundle models_multi.pt --member path/to/model.pt --out C:/tmp/model.pt
"""
from __future__ import annotations
import argparse
import io
import os
import sys
from pathlib import Path
from typing import Iterable, List
try:
import torch # Only needed for PT container
except Exception: # pragma: no cover - optional for ZIP-only usage
torch = None # type: ignore
import zipfile
def find_pt_files(source: Path, include: Iterable[str] | None = None, exclude: Iterable[str] | None = None) -> List[Path]:
    """Return the ``.pt`` files under *source* that pass the include/exclude filters.

    Each candidate's path relative to *source* (with forward slashes) is tested
    with :meth:`pathlib.PurePath.match`. Relative patterns match from the
    right, so ``"*.pt"`` matches files at any depth.

    Args:
        source: Root directory scanned recursively for ``*.pt``.
        include: Patterns a file must match (any one suffices). ``None`` or an
            empty iterable means ``["*.pt"]``.
        exclude: Patterns that drop a file if any of them match.

    Returns:
        Matching regular files, sorted by relative path so that repeated runs
        over the same tree yield bundles with the same member order.
    """
    include = list(include or ["*.pt"])  # default include all .pt
    exclude = list(exclude or [])
    files: List[Path] = []
    for p in source.rglob("*.pt"):
        # rglob also yields directories named "*.pt"; only regular files can be bundled.
        if not p.is_file():
            continue
        rel = Path(str(p.relative_to(source)).replace("\\", "/"))
        if not any(rel.match(pat) for pat in include):
            continue
        if any(rel.match(pat) for pat in exclude):
            continue
        files.append(p)
    # Filesystem iteration order is unspecified; sort for reproducible bundles.
    return sorted(files, key=lambda f: str(f.relative_to(source)).replace("\\", "/"))
def create_zip_bundle(source: Path, out_path: Path, includes: Iterable[str] | None = None, excludes: Iterable[str] | None = None) -> int:
    """Pack the selected ``.pt`` files under *source* into a DEFLATE ZIP at *out_path*.

    Each file is stored under its path relative to *source*. Its bytes are
    copied unchanged into the archive, and the originals are only read.

    Returns:
        How many files were added to the archive.
    """
    selected = find_pt_files(source, includes, excludes)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(out_path, mode="w", compression=zipfile.ZIP_DEFLATED) as archive:
        for model_file in selected:
            archive.write(model_file, arcname=model_file.relative_to(source))
    return len(selected)
def create_pt_container(source: Path, out_path: Path, includes: Iterable[str] | None = None, excludes: Iterable[str] | None = None) -> int:
    """Save the selected ``.pt`` files under *source* as one ``torch.save``'d dict.

    The saved object maps each file's forward-slash path relative to *source*
    to that file's raw bytes. Source files are only read, never modified.

    Returns:
        Number of files stored in the container.

    Raises:
        RuntimeError: If torch could not be imported.
    """
    if torch is None:
        raise RuntimeError("torch is required for PT container mode. Install torch and retry.")
    # The container is itself a .pt file. When it lives under *source* (e.g. the
    # documented `--source . --out models_multi.pt`), a previous run's output
    # would otherwise be swept into the new container and nest on every rerun.
    out_resolved = out_path.resolve()
    files = [f for f in find_pt_files(source, includes, excludes) if f.resolve() != out_resolved]
    payload = {}
    for f in files:
        rel = str(f.relative_to(source)).replace("\\", "/")
        payload[rel] = f.read_bytes()  # store exact bytes (no mutation)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(payload, out_path)
    return len(files)
def list_bundle(bundle: Path) -> List[str]:
    """Return the sorted member paths stored in *bundle*.

    A ``.zip`` suffix (case-insensitive) selects ZIP mode, and directory
    entries are skipped. Any other suffix is treated as a PT container: the
    file is loaded with ``torch.load`` and its dict keys are returned as
    strings. Both formats return a sorted list so listings are comparable.

    Raises:
        RuntimeError: PT container requested but torch is not importable.
        ValueError: The PT container did not load as a dict.
    """
    if bundle.suffix.lower() == ".zip":
        with zipfile.ZipFile(bundle, "r") as zf:
            return sorted(i.filename for i in zf.infolist() if not i.is_dir())
    if torch is None:
        raise RuntimeError("torch is required to list PT container contents.")
    # SECURITY: torch.load is pickle-based (its weights_only default depends on
    # the torch version); do not open containers from untrusted sources.
    data = torch.load(bundle, map_location="cpu")
    if isinstance(data, dict):
        return sorted(map(str, data.keys()))
    raise ValueError("Unsupported PT container format: expected dict mapping.")
def extract_member(bundle: Path, member: str, out_path: Path) -> None:
    """Write the bytes of *member* from *bundle* to *out_path*.

    A ``.zip`` suffix (case-insensitive) selects ZIP mode; anything else is
    read as a PT container via ``torch.load``. For both formats the parent
    directories of *out_path* are created. The destination is only touched
    after the member has been found.

    Raises:
        KeyError: A ZIP bundle has no entry named *member* (from zipfile).
        RuntimeError: PT container requested but torch is not importable.
        ValueError: The PT container did not load as a dict.
        FileNotFoundError: The PT container has no key *member*.
        TypeError: The PT container entry is not raw bytes.
    """
    import shutil

    if bundle.suffix.lower() == ".zip":
        with zipfile.ZipFile(bundle, "r") as zf:
            # Look the member up first so a bad name leaves no empty file or new dirs behind.
            info = zf.getinfo(member)
            out_path.parent.mkdir(parents=True, exist_ok=True)
            with zf.open(info, "r") as src, open(out_path, "wb") as dst:
                # Stream in chunks rather than holding the whole model in memory.
                shutil.copyfileobj(src, dst)
        return
    if torch is None:
        raise RuntimeError("torch is required to extract from PT container.")
    # SECURITY: torch.load is pickle-based (its weights_only default depends on
    # the torch version); do not open containers from untrusted sources.
    data = torch.load(bundle, map_location="cpu")
    if not isinstance(data, dict):
        raise ValueError("Unsupported PT container format: expected dict mapping.")
    if member not in data:
        raise FileNotFoundError(f"member not found in container: {member}")
    blob = data[member]
    # Fail before creating the output file if the entry is not raw bytes.
    if not isinstance(blob, (bytes, bytearray)):
        raise TypeError(f"container entry is not raw bytes: {member}")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_bytes(blob)
def main(argv: List[str] | None = None) -> int:
    """Parse *argv* (``None`` means ``sys.argv[1:]``) and run the chosen sub-command.

    Sub-commands: ``zip`` and ``pt`` build a bundle, ``list`` prints a bundle's
    members one per line, and ``extract`` writes one member to a file.

    Returns:
        Exit status ``0`` after a sub-command completes. ``1`` is returned only
        if no sub-command matched, which argparse's ``required=True`` already
        prevents.
    """
    parser = argparse.ArgumentParser(description="Bundle multiple .pt files without modifying originals.")
    commands = parser.add_subparsers(dest="cmd", required=True)

    def add_build_command(name: str, help_text: str, out_help: str) -> None:
        # "zip" and "pt" accept identical options; only their help texts differ.
        sub_parser = commands.add_parser(name, help=help_text)
        sub_parser.add_argument("--source", default=".", help="Root directory to scan for .pt files.")
        sub_parser.add_argument("--out", required=True, help=out_help)
        sub_parser.add_argument("--include", nargs="*", default=["*.pt"], help="Glob patterns to include.")
        sub_parser.add_argument("--exclude", nargs="*", default=[], help="Glob patterns to exclude.")

    add_build_command("zip", "Create a ZIP archive of .pt files.", "Output ZIP path.")
    add_build_command("pt", "Create a single .pt container (dict of bytes).", "Output PT path (e.g., models_multi.pt).")

    lister = commands.add_parser("list", help="List contents of a bundle (ZIP or PT container).")
    lister.add_argument("--bundle", required=True, help="Path to models_bundle.zip or models_multi.pt.")

    extractor = commands.add_parser("extract", help="Extract a single member from the bundle.")
    extractor.add_argument("--bundle", required=True, help="Bundle path (ZIP or PT container).")
    extractor.add_argument("--member", required=True, help="Member path inside the bundle.")
    extractor.add_argument("--out", required=True, help="Destination file path to write.")

    opts = parser.parse_args(argv)
    chosen = opts.cmd

    if chosen == "zip":
        written = create_zip_bundle(Path(opts.source), Path(opts.out), opts.include, opts.exclude)
        print(f"ZIP bundle written: {opts.out} ({written} files)")
        return 0
    if chosen == "pt":
        written = create_pt_container(Path(opts.source), Path(opts.out), opts.include, opts.exclude)
        print(f"PT container written: {opts.out} ({written} files)")
        return 0
    if chosen == "list":
        print("\n".join(list_bundle(Path(opts.bundle)))) if False else [print(entry) for entry in list_bundle(Path(opts.bundle))]
        return 0
    if chosen == "extract":
        extract_member(Path(opts.bundle), opts.member, Path(opts.out))
        print(f"Extracted {opts.member} -> {opts.out}")
        return 0

    parser.print_help()
    return 1
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())
|