File size: 1,879 Bytes
c9ea96d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#!/usr/bin/env python3
"""Convert HF release safetensors checkpoints into official repo checkpoint layouts."""

from __future__ import annotations

import argparse
from argparse import Namespace
from pathlib import Path

import torch
from safetensors.torch import load_file


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description=(
            "Convert a safetensors release checkpoint into the PyTorch checkpoint "
            "layout expected by the official RT-DETRv4 or RF-DETR repositories."
        )
    )
    parser.add_argument(
        "--framework",
        choices=("rtdetrv4", "rfdetr"),
        required=True,
        help="Target official repository format.",
    )
    parser.add_argument(
        "--input",
        type=Path,
        required=True,
        help="Input .safetensors checkpoint path.",
    )
    parser.add_argument(
        "--output",
        type=Path,
        required=True,
        help="Output .pth checkpoint path.",
    )
    parser.add_argument(
        "--class-names",
        nargs="+",
        default=["person", "head"],
        help="Class names to store in RF-DETR checkpoint metadata.",
    )
    return parser.parse_args()


def main() -> None:
    """Load a safetensors checkpoint and re-save it as a ``torch.save`` file.

    The loaded tensors are stored under the ``"model"`` key. For any framework
    other than ``rtdetrv4`` the saved dict also carries an ``"args"`` namespace
    holding the class names. A short summary is printed once the file is
    written.
    """
    args = parse_args()

    # Read the weights first, then make sure the destination directory exists.
    state_dict = load_file(str(args.input))
    args.output.parent.mkdir(parents=True, exist_ok=True)

    checkpoint = {"model": state_dict}
    if args.framework != "rtdetrv4":
        checkpoint["args"] = Namespace(class_names=args.class_names)
    torch.save(checkpoint, args.output)

    summary = [
        f"Converted {args.input} -> {args.output}",
        f"Framework: {args.framework}",
        f"Tensors: {len(state_dict)}",
    ]
    if args.framework == "rfdetr":
        summary.append(f"Class names: {args.class_names}")
    for line in summary:
        print(line)


# Run the conversion only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()