| | |
| | """Convert HF release safetensors checkpoints into official repo checkpoint layouts.""" |
| |
|
| | from __future__ import annotations |
| |
|
| | import argparse |
| | from argparse import Namespace |
| | from pathlib import Path |
| |
|
| | import torch |
| | from safetensors.torch import load_file |
| |
|
| |
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the converter's command-line options.

    Args:
        argv: Explicit argument list to parse (without the program name).
            ``None`` (the default) makes argparse read ``sys.argv[1:]``,
            which is the behaviour of the command-line entry point; passing a
            list lets other scripts and tests drive the parser directly.

    Returns:
        Namespace with ``framework`` (``"rtdetrv4"`` or ``"rfdetr"``),
        ``input`` and ``output`` as :class:`pathlib.Path`, and
        ``class_names`` as a list of strings.

    Raises:
        SystemExit: Raised by argparse on missing/invalid options or ``-h``.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Convert a safetensors release checkpoint into the PyTorch checkpoint "
            "layout expected by the official RT-DETRv4 or RF-DETR repositories."
        )
    )
    parser.add_argument(
        "--framework",
        choices=("rtdetrv4", "rfdetr"),
        required=True,
        help="Target official repository format.",
    )
    parser.add_argument(
        "--input",
        type=Path,
        required=True,
        help="Input .safetensors checkpoint path.",
    )
    parser.add_argument(
        "--output",
        type=Path,
        required=True,
        help="Output .pth checkpoint path.",
    )
    # Only consumed when --framework is "rfdetr"; the parser is rebuilt on
    # every call, so the default list is never shared between calls.
    parser.add_argument(
        "--class-names",
        nargs="+",
        default=["person", "head"],
        help="Class names to store in RF-DETR checkpoint metadata.",
    )
    return parser.parse_args(argv)
| |
|
| |
|
def _build_payload(framework, state_dict, class_names):
    """Wrap *state_dict* in the top-level dict layout for *framework*.

    Every layout stores the weights under ``"model"``; any framework other
    than ``"rtdetrv4"`` additionally gets an ``"args"`` Namespace carrying
    ``class_names``.
    """
    payload = {"model": state_dict}
    if framework != "rtdetrv4":
        payload["args"] = Namespace(class_names=class_names)
    return payload


def main() -> None:
    """Load the safetensors input, re-save it as a .pth payload, print a summary."""
    cli = parse_args()
    tensors = load_file(str(cli.input))
    # Make sure the destination directory exists before torch.save writes to it.
    cli.output.parent.mkdir(parents=True, exist_ok=True)

    torch.save(_build_payload(cli.framework, tensors, cli.class_names), cli.output)

    summary = [
        f"Converted {cli.input} -> {cli.output}",
        f"Framework: {cli.framework}",
        f"Tensors: {len(tensors)}",
    ]
    if cli.framework == "rfdetr":
        summary.append(f"Class names: {cli.class_names}")
    for line in summary:
        print(line)
| |
|
| |
|
if __name__ == "__main__":
    # Run the conversion only when executed as a script, not when imported.
    main()
| |
|