#!/usr/bin/env python3
"""Convert HF release safetensors checkpoints into official repo checkpoint layouts."""

from __future__ import annotations

import argparse
from argparse import Namespace
from pathlib import Path

import torch
from safetensors.torch import load_file


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the converter's command-line options.

    Args:
        argv: Argument strings to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, which is what the command-line entry point
            relies on.

    Returns:
        Namespace with ``framework`` (``"rtdetrv4"`` or ``"rfdetr"``),
        ``input`` and ``output`` (both ``Path``), and ``class_names``
        (list of strings, ``["person", "head"]`` unless overridden).
    """
    parser = argparse.ArgumentParser(
        description=(
            "Convert a safetensors release checkpoint into the PyTorch checkpoint "
            "layout expected by the official RT-DETRv4 or RF-DETR repositories."
        )
    )
    parser.add_argument(
        "--framework",
        choices=("rtdetrv4", "rfdetr"),
        required=True,
        help="Target official repository format.",
    )
    parser.add_argument(
        "--input",
        type=Path,
        required=True,
        help="Input .safetensors checkpoint path.",
    )
    parser.add_argument(
        "--output",
        type=Path,
        required=True,
        help="Output .pth checkpoint path.",
    )
    parser.add_argument(
        "--class-names",
        nargs="+",
        default=["person", "head"],
        help="Class names to store in RF-DETR checkpoint metadata.",
    )
    return parser.parse_args(argv)


def main(argv: list[str] | None = None) -> None:
    """Load a safetensors checkpoint and save it as a ``.pth`` payload.

    The tensors are stored under the ``"model"`` key. For the ``rfdetr``
    framework the payload additionally carries an ``"args"`` entry: an
    ``argparse.Namespace`` whose ``class_names`` attribute holds the
    ``--class-names`` values. A short summary is printed after saving.

    Args:
        argv: Argument strings forwarded to ``parse_args``; ``None`` means
            the process command line.

    Raises:
        FileNotFoundError: If ``--input`` does not point to an existing file.
    """
    args = parse_args(argv)

    # Fail early with a clear message, before any output directory is created.
    if not args.input.is_file():
        raise FileNotFoundError(f"Input checkpoint not found: {args.input}")

    state_dict = load_file(str(args.input))
    args.output.parent.mkdir(parents=True, exist_ok=True)

    if args.framework == "rtdetrv4":
        payload = {"model": state_dict}
    else:
        # NOTE(review): the Namespace is pickled into the checkpoint, so
        # consumers presumably need ``torch.load(..., weights_only=False)``
        # (or an allow-listed ``argparse.Namespace``) on recent PyTorch
        # releases -- confirm against the target repository's loader.
        payload = {
            "model": state_dict,
            "args": Namespace(class_names=args.class_names),
        }

    torch.save(payload, args.output)

    print(f"Converted {args.input} -> {args.output}")
    print(f"Framework: {args.framework}")
    print(f"Tensors: {len(state_dict)}")
    if args.framework == "rfdetr":
        print(f"Class names: {args.class_names}")


if __name__ == "__main__":
    main()