ultralytics
Eval Results
File size: 6,702 Bytes
8c738c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
#!/usr/bin/env python
"""
PT bundler: Bundle multiple .pt files into a single archive without modifying originals.

Supports two formats:
  1) ZIP archive (recommended) – exact bytes of each .pt preserved.
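  2) PT container – a single .pt file written with torch.save, holding a dict {relative_path: bytes}.
     Reading it goes through torch.load, which unpickles the file, so only open trusted containers.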

CLI examples (PowerShell):
  # Create ZIP bundle from current repo
  python tools/pt_bundle.py zip --source . --out models_bundle.zip

  # Create PT container bundle
  python tools/pt_bundle.py pt --source . --out models_multi.pt

  # List contents
  python tools/pt_bundle.py list --bundle models_bundle.zip
  python tools/pt_bundle.py list --bundle models_multi.pt

  # Extract a single model from bundle to a path
  python tools/pt_bundle.py extract --bundle models_multi.pt --member path/to/model.pt --out C:/tmp/model.pt
"""

from __future__ import annotations

import argparse
import io
import os
import shutil
import sys
import zipfile
from pathlib import Path
from typing import Iterable, List

try:
    import torch  # Only needed for PT container
except Exception:  # pragma: no cover - optional for ZIP-only usage
    torch = None  # type: ignore


def find_pt_files(source: Path, include: Iterable[str] | None = None, exclude: Iterable[str] | None = None) -> List[Path]:
    """Return the ``.pt`` files under *source* that pass the include/exclude filters.

    Patterns are matched with ``Path.match`` against each file's path relative
    to *source*. A relative pattern matches from the right, so ``"*.pt"``
    matches files at any depth, while ``"sub/*.pt"`` matches only files that
    sit directly inside a ``sub`` directory.

    Args:
        source: Root directory scanned recursively.
        include: Patterns of which a file must match at least one. ``None`` or
            an empty iterable means ``["*.pt"]``.
        exclude: Patterns that drop a file if any of them matches.

    Returns:
        Matching regular files, sorted so bundles are built in a reproducible
        order regardless of filesystem iteration order.
    """
    include = list(include or ["*.pt"])  # default include all .pt
    exclude = list(exclude or [])
    files: List[Path] = []
    for p in source.rglob("*.pt"):
        # rglob also yields directories whose names end in ".pt"; only real
        # files can be bundled (reading a directory raises IsADirectoryError).
        if not p.is_file():
            continue
        rel = p.relative_to(source)
        if not any(rel.match(pat) for pat in include):
            continue
        if any(rel.match(pat) for pat in exclude):
            continue
        files.append(p)
    return sorted(files)


def create_zip_bundle(source: Path, out_path: Path, includes: Iterable[str] | None = None, excludes: Iterable[str] | None = None) -> int:
    """Write the matching ``.pt`` files under *source* into a ZIP archive.

    Each file is stored under its POSIX-style path relative to *source*.
    Deflate compression is lossless, so extracted members are byte-identical
    to the originals.

    Args:
        source: Root directory scanned for ``.pt`` files.
        out_path: Destination archive; missing parent directories are created.
        includes: Include glob patterns forwarded to ``find_pt_files``.
        excludes: Exclude glob patterns forwarded to ``find_pt_files``.

    Returns:
        The number of files written into the archive.
    """
    # If *out_path* has a .pt suffix and lives under *source*, it would be
    # picked up by the scan. ZipFile truncates it on open, so the archive would
    # then be read back into itself while being written. Skip it explicitly.
    target = out_path.resolve()
    files = [f for f in find_pt_files(source, includes, excludes) if f.resolve() != target]
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(out_path, mode="w", compression=zipfile.ZIP_DEFLATED) as zf:
        for f in files:
            zf.write(f, f.relative_to(source).as_posix())
    return len(files)


def create_pt_container(source: Path, out_path: Path, includes: Iterable[str] | None = None, excludes: Iterable[str] | None = None) -> int:
    """Pack the matching ``.pt`` files under *source* into one ``torch.save`` file.

    The saved object is a ``dict`` mapping each file's POSIX-style path
    relative to *source* to that file's raw bytes. The originals are only read,
    never deserialized or modified.

    Args:
        source: Root directory scanned for ``.pt`` files.
        out_path: Destination container; missing parent directories are created.
        includes: Include glob patterns forwarded to ``find_pt_files``.
        excludes: Exclude glob patterns forwarded to ``find_pt_files``.

    Returns:
        The number of files stored in the container.

    Raises:
        RuntimeError: torch could not be imported.
    """
    if torch is None:
        raise RuntimeError("torch is required for PT container mode. Install torch and retry.")
    # The container is itself a .pt file. With the documented usage
    # (``--source . --out models_multi.pt``), a previous container would
    # otherwise be swept up and nested inside the new one on every re-run.
    target = out_path.resolve()
    payload = {}
    for f in find_pt_files(source, includes, excludes):
        if f.resolve() == target:
            continue
        rel = f.relative_to(source).as_posix()
        payload[rel] = f.read_bytes()  # store exact bytes (no mutation)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(payload, out_path)
    return len(payload)


def list_bundle(bundle: Path) -> List[str]:
    """Return the member names stored in *bundle*.

    A ``.zip`` suffix (any letter case) selects ZIP mode. It returns file
    entries in archive order and skips directory entries. Any other suffix is
    treated as a PT container: the file is loaded with ``torch.load`` onto the
    CPU, and its dict keys are returned as sorted strings.

    Raises:
        RuntimeError: a PT container was given but torch is unavailable.
        ValueError: the PT container does not hold a dict.
    """
    if bundle.suffix.lower() != ".zip":
        if torch is None:
            raise RuntimeError("torch is required to list PT container contents.")
        # NOTE(review): torch.load may unpickle arbitrary objects (depending on
        # the installed torch's weights_only default) — only list trusted bundles.
        container = torch.load(bundle, map_location="cpu")
        if not isinstance(container, dict):
            raise ValueError("Unsupported PT container format: expected dict mapping.")
        return sorted(str(key) for key in container)

    with zipfile.ZipFile(bundle, "r") as archive:
        names: List[str] = []
        for info in archive.infolist():
            if info.is_dir():
                continue
            names.append(info.filename)
        return names


def extract_member(bundle: Path, member: str, out_path: Path) -> None:
    """Write the bytes of *member* from *bundle* to *out_path*.

    A ``.zip`` suffix (any letter case) selects ZIP mode; anything else is
    treated as a PT container holding a ``{relative_path: bytes}`` dict. In
    both modes, missing parent directories of *out_path* are created.

    Raises:
        KeyError: *member* is not in a ZIP bundle (raised by ``zipfile``).
        FileNotFoundError: *member* is not in a PT container.
        RuntimeError: a PT container was given but torch is unavailable.
        ValueError: the PT container does not hold a dict.
    """
    if bundle.suffix.lower() == ".zip":
        with zipfile.ZipFile(bundle, "r") as zf:
            # Open the member first so a missing name (KeyError) leaves no
            # empty output file or stray directories behind.
            with zf.open(member, "r") as src:
                out_path.parent.mkdir(parents=True, exist_ok=True)
                with open(out_path, "wb") as dst:
                    # Stream in chunks instead of buffering the whole member.
                    shutil.copyfileobj(src, dst)
        return

    if torch is None:
        raise RuntimeError("torch is required to extract from PT container.")
    # NOTE(review): torch.load may unpickle arbitrary objects (depending on the
    # installed torch's weights_only default) — only extract from trusted bundles.
    data = torch.load(bundle, map_location="cpu")
    if not isinstance(data, dict):
        raise ValueError("Unsupported PT container format: expected dict mapping.")
    if member not in data:
        raise FileNotFoundError(f"member not found in container: {member}")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "wb") as fh:
        fh.write(data[member])


def main(argv: List[str] | None = None) -> int:
    """Parse *argv* (``sys.argv[1:]`` when ``None``) and run one subcommand.

    Subcommands: ``zip`` and ``pt`` build a bundle, ``list`` prints member
    names one per line, and ``extract`` writes one member to disk. Returns 0
    after a subcommand runs. argparse exits with status 2 on invalid usage.
    """
    parser = argparse.ArgumentParser(description="Bundle multiple .pt files without modifying originals.")
    commands = parser.add_subparsers(dest="cmd", required=True)

    def add_scan_options(sp: argparse.ArgumentParser, out_help: str) -> None:
        # "zip" and "pt" take identical scan/filter options; only --out's help differs.
        sp.add_argument("--source", default=".", help="Root directory to scan for .pt files.")
        sp.add_argument("--out", required=True, help=out_help)
        sp.add_argument("--include", nargs="*", default=["*.pt"], help="Glob patterns to include.")
        sp.add_argument("--exclude", nargs="*", default=[], help="Glob patterns to exclude.")

    add_scan_options(
        commands.add_parser("zip", help="Create a ZIP archive of .pt files."),
        "Output ZIP path.",
    )
    add_scan_options(
        commands.add_parser("pt", help="Create a single .pt container (dict of bytes)."),
        "Output PT path (e.g., models_multi.pt).",
    )

    lister = commands.add_parser("list", help="List contents of a bundle (ZIP or PT container).")
    lister.add_argument("--bundle", required=True, help="Path to models_bundle.zip or models_multi.pt.")

    extractor = commands.add_parser("extract", help="Extract a single member from the bundle.")
    extractor.add_argument("--bundle", required=True, help="Bundle path (ZIP or PT container).")
    extractor.add_argument("--member", required=True, help="Member path inside the bundle.")
    extractor.add_argument("--out", required=True, help="Destination file path to write.")

    ns = parser.parse_args(argv)
    command = ns.cmd

    if command == "zip":
        n_files = create_zip_bundle(Path(ns.source), Path(ns.out), ns.include, ns.exclude)
        print(f"ZIP bundle written: {ns.out} ({n_files} files)")
    elif command == "pt":
        n_files = create_pt_container(Path(ns.source), Path(ns.out), ns.include, ns.exclude)
        print(f"PT container written: {ns.out} ({n_files} files)")
    elif command == "list":
        for name in list_bundle(Path(ns.bundle)):
            print(name)
    elif command == "extract":
        extract_member(Path(ns.bundle), ns.member, Path(ns.out))
        print(f"Extracted {ns.member} -> {ns.out}")
    else:
        # Defensive fallback; unreachable while the subcommand is required.
        parser.print_help()
        return 1
    return 0


if __name__ == "__main__":
    # sys.exit raises SystemExit with main()'s return value as the exit status.
    sys.exit(main())