#!/usr/bin/env python3
"""Convert HF release safetensors checkpoints into official repo checkpoint layouts."""
from __future__ import annotations
import argparse
from argparse import Namespace
from pathlib import Path
import torch
from safetensors.torch import load_file
def parse_args() -> argparse.Namespace:
    """Parse and return the command-line options for the converter.

    Returns:
        argparse.Namespace with ``framework``, ``input``, ``output`` and
        ``class_names`` attributes.
    """
    arg_parser = argparse.ArgumentParser(
        description=(
            "Convert a safetensors release checkpoint into the PyTorch checkpoint "
            "layout expected by the official RT-DETRv4 or RF-DETR repositories."
        )
    )
    # Which official repository's checkpoint layout to emit.
    arg_parser.add_argument(
        "--framework",
        choices=("rtdetrv4", "rfdetr"),
        required=True,
        help="Target official repository format.",
    )
    # Source and destination paths.
    arg_parser.add_argument(
        "--input", type=Path, required=True,
        help="Input .safetensors checkpoint path.",
    )
    arg_parser.add_argument(
        "--output", type=Path, required=True,
        help="Output .pth checkpoint path.",
    )
    # Only consumed for the RF-DETR layout (stored in checkpoint metadata).
    arg_parser.add_argument(
        "--class-names", nargs="+", default=["person", "head"],
        help="Class names to store in RF-DETR checkpoint metadata.",
    )
    return arg_parser.parse_args()
def main() -> None:
    """Convert one safetensors checkpoint into the requested repo layout.

    Loads the tensors from ``--input``, wraps them in the dict structure the
    target framework expects, and writes the result to ``--output`` with
    ``torch.save``. Prints a short summary afterwards.
    """
    args = parse_args()
    tensors = load_file(str(args.input))

    # Make sure the destination directory exists before saving.
    args.output.parent.mkdir(parents=True, exist_ok=True)

    # --framework is restricted to exactly two choices, so this two-way
    # branch covers all cases.
    if args.framework == "rfdetr":
        # RF-DETR checkpoints additionally carry an argparse Namespace with
        # class-name metadata under the "args" key.
        payload = {
            "model": tensors,
            "args": Namespace(class_names=args.class_names),
        }
    else:
        payload = {"model": tensors}

    torch.save(payload, args.output)

    # Summary for the operator.
    print(f"Converted {args.input} -> {args.output}")
    print(f"Framework: {args.framework}")
    print(f"Tensors: {len(tensors)}")
    if args.framework == "rfdetr":
        print(f"Class names: {args.class_names}")


if __name__ == "__main__":
    main()