# person_and_head_detection/scripts/convert_release_checkpoint.py
# Author: NikhilSandy
# Commit message: Initial model release
# Commit: c9ea96d (verified)
#!/usr/bin/env python3
"""Convert HF release safetensors checkpoints into official repo checkpoint layouts."""
from __future__ import annotations
import argparse
from argparse import Namespace
from pathlib import Path
import torch
from safetensors.torch import load_file
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description=(
"Convert a safetensors release checkpoint into the PyTorch checkpoint "
"layout expected by the official RT-DETRv4 or RF-DETR repositories."
)
)
parser.add_argument(
"--framework",
choices=("rtdetrv4", "rfdetr"),
required=True,
help="Target official repository format.",
)
parser.add_argument(
"--input",
type=Path,
required=True,
help="Input .safetensors checkpoint path.",
)
parser.add_argument(
"--output",
type=Path,
required=True,
help="Output .pth checkpoint path.",
)
parser.add_argument(
"--class-names",
nargs="+",
default=["person", "head"],
help="Class names to store in RF-DETR checkpoint metadata.",
)
return parser.parse_args()
def main() -> None:
    """Re-save a safetensors checkpoint as a ``torch.save`` payload.

    The state dict loaded from ``--input`` is always stored under the
    ``"model"`` key. For every framework other than ``rtdetrv4`` (i.e.
    ``rfdetr``), an ``argparse.Namespace`` holding ``class_names`` is also
    stored under ``"args"``. A short summary is printed to stdout afterwards.
    """
    opts = parse_args()
    tensors = load_file(str(opts.input))

    # Create the destination directory up front so torch.save can write into it.
    opts.output.parent.mkdir(parents=True, exist_ok=True)

    checkpoint = {"model": tensors}
    if not opts.framework == "rtdetrv4":
        # RF-DETR layout: the metadata namespace travels alongside the weights.
        checkpoint["args"] = Namespace(class_names=opts.class_names)
    torch.save(checkpoint, opts.output)

    report = [
        f"Converted {opts.input} -> {opts.output}",
        f"Framework: {opts.framework}",
        f"Tensors: {len(tensors)}",
    ]
    if opts.framework == "rfdetr":
        report.append(f"Class names: {opts.class_names}")
    for line in report:
        print(line)
if __name__ == "__main__":
    # Script entry point only; importing this module performs no conversion.
    # main() returns None, so SystemExit(None) exits with status 0.
    raise SystemExit(main())