#!/usr/bin/env python3 """Run release checkpoints through the official RT-DETRv4 or RF-DETR repositories.""" from __future__ import annotations import argparse import json import subprocess import sys from argparse import Namespace from pathlib import Path import torch from safetensors.torch import load_file IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp", ".webp"} VIDEO_SUFFIXES = {".mp4", ".avi", ".mov", ".mkv"} DEFAULT_CLASS_NAMES = ["person", "head"] def convert_checkpoint( framework: str, checkpoint_path: Path, output_dir: Path, class_names: list[str] | None = None, ) -> Path: checkpoint_path = checkpoint_path.resolve() output_dir.mkdir(parents=True, exist_ok=True) if checkpoint_path.suffix == ".pth": return checkpoint_path if checkpoint_path.suffix != ".safetensors": raise ValueError(f"Unsupported checkpoint format: {checkpoint_path}") state_dict = load_file(str(checkpoint_path)) output_path = output_dir / f"{checkpoint_path.stem}.pth" if framework == "rtdetrv4": payload = {"model": state_dict} else: payload = { "model": state_dict, "args": Namespace(class_names=class_names or DEFAULT_CLASS_NAMES), } torch.save(payload, output_path) return output_path def infer_teacher_dim(checkpoint_path: Path, explicit: int | None) -> int: if explicit is not None: return explicit checkpoint_path = checkpoint_path.resolve() state_dict = None if checkpoint_path.suffix == ".safetensors": state_dict = load_file(str(checkpoint_path)) elif checkpoint_path.suffix == ".pth": payload = torch.load(checkpoint_path, map_location="cpu", weights_only=False) state_dict = payload["model"] if isinstance(payload, dict) and "model" in payload else payload if isinstance(state_dict, dict) and "encoder.feature_projector.0.weight" in state_dict: return int(state_dict["encoder.feature_projector.0.weight"].shape[0]) name = checkpoint_path.as_posix().lower() if "cradio" in name or "cradiov4" in name or "c-radio" in name: return 1152 return 768 def write_rtdetrv4_config(repo_path: Path, output_dir: Path, teacher_dim: int) -> Path: config_path = output_dir / "rtdetrv4_person_head_inference.yml" base_config = (repo_path / "configs" / "rtv4" / "rtv4_hgnetv2_s_coco.yml").resolve() config_text = ( "__include__: [\n" f" '{base_config}'\n" "]\n\n" "num_classes: 2\n" "remap_mscoco_category: False\n\n" "HGNetv2:\n" " pretrained: False\n\n" "HybridEncoder:\n" f" distill_teacher_dim: {teacher_dim}\n" ) config_path.write_text(config_text) return config_path def run_rtdetrv4(args: argparse.Namespace) -> None: repo_path = args.repo.resolve() input_path = args.input.resolve() output_dir = args.output_dir.resolve() output_dir.mkdir(parents=True, exist_ok=True) converted_ckpt = convert_checkpoint("rtdetrv4", args.checkpoint, output_dir / "artifacts") teacher_dim = infer_teacher_dim(args.checkpoint, args.teacher_dim) config_path = write_rtdetrv4_config(repo_path, output_dir, teacher_dim) command = [ sys.executable, str((repo_path / "tools" / "inference" / "torch_inf.py").resolve()), "-c", str(config_path), "-r", str(converted_ckpt), "-i", str(input_path), "-d", args.device, ] subprocess.run(command, cwd=output_dir, check=True) result_name = "torch_results.jpg" if input_path.suffix.lower() in VIDEO_SUFFIXES: result_name = "torch_results.mp4" print( json.dumps( { "framework": "rtdetrv4", "converted_checkpoint": str(converted_ckpt), "generated_config": str(config_path), "result": str(output_dir / result_name), "teacher_dim": teacher_dim, }, indent=2, ) ) def build_label_lookup(class_names) -> dict[int, str]: """Build an int -> str lookup from whatever format class_names is in.""" if isinstance(class_names, dict): return {int(k): v for k, v in class_names.items()} if isinstance(class_names, (list, tuple)): return {i: name for i, name in enumerate(class_names)} return {} def resolve_class_name(label_lookup: dict[int, str], raw_class_id: int) -> str: if raw_class_id in label_lookup: return label_lookup[raw_class_id] if raw_class_id + 1 in label_lookup: return label_lookup[raw_class_id + 1] return str(raw_class_id) def run_rfdetr(args: argparse.Namespace) -> None: import numpy as np import supervision as sv from PIL import Image repo_path = args.repo.resolve() output_dir = args.output_dir.resolve() input_path = args.input.resolve() output_dir.mkdir(parents=True, exist_ok=True) if input_path.suffix.lower() not in IMAGE_SUFFIXES: raise ValueError("RF-DETR wrapper currently supports image inference only.") sys.path.insert(0, str(repo_path)) sys.path.insert(0, str(repo_path / "src")) from rfdetr import RFDETRSmall converted_ckpt = convert_checkpoint( "rfdetr", args.checkpoint, output_dir / "artifacts", class_names=args.class_names, ) model = RFDETRSmall( pretrain_weights=str(converted_ckpt), device=args.device, ) image = Image.open(input_path).convert("RGB") detections = model.predict(image, threshold=args.threshold) label_lookup = build_label_lookup(getattr(model, "class_names", args.class_names)) labels = [] for class_id, confidence in zip(detections.class_id.tolist(), detections.confidence.tolist()): class_name = resolve_class_name(label_lookup, int(class_id)) labels.append(f"{class_name} {confidence:.2f}") image_np = np.array(image) annotated = sv.BoxAnnotator().annotate(scene=image_np, detections=detections) annotated = sv.LabelAnnotator().annotate(scene=annotated, detections=detections, labels=labels) output_image = output_dir / f"{input_path.stem}_rfdetr.jpg" output_json = output_dir / f"{input_path.stem}_rfdetr.json" Image.fromarray(annotated).save(output_image) predictions = [] for box, confidence, class_id in zip( detections.xyxy.tolist(), detections.confidence.tolist(), detections.class_id.tolist(), ): raw_id = int(class_id) predictions.append( { "bbox_xyxy": [round(float(v), 4) for v in box], "confidence": round(float(confidence), 6), "class_id": raw_id, "class_name": resolve_class_name(label_lookup, raw_id), } ) output_json.write_text(json.dumps(predictions, indent=2)) print( json.dumps( { "framework": "rfdetr", "converted_checkpoint": str(converted_ckpt), "result_image": str(output_image), "result_json": str(output_json), }, indent=2, ) ) def build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( description="Run this release through the official RT-DETRv4 or RF-DETR repositories." ) subparsers = parser.add_subparsers(dest="framework", required=True) rtdetr_parser = subparsers.add_parser("rtdetrv4", help="Run official RT-DETRv4 inference.") rtdetr_parser.add_argument("--repo", type=Path, required=True, help="Path to the official RT-DETRv4 repository.") rtdetr_parser.add_argument("--checkpoint", type=Path, required=True, help="Release checkpoint (.safetensors or .pth).") rtdetr_parser.add_argument("--input", type=Path, required=True, help="Input image or video path.") rtdetr_parser.add_argument("--device", default="cpu", help="Inference device passed to official script.") rtdetr_parser.add_argument( "--output-dir", type=Path, default=Path("outputs/rtdetrv4"), help="Directory where converted weights, temp config, and outputs are written.", ) rtdetr_parser.add_argument( "--teacher-dim", type=int, choices=(768, 1152), default=None, help="Override the RT-DETRv4 distillation projection dimension if auto-detection is wrong.", ) rtdetr_parser.set_defaults(func=run_rtdetrv4) rfdetr_parser = subparsers.add_parser("rfdetr", help="Run official RF-DETR inference.") rfdetr_parser.add_argument("--repo", type=Path, required=True, help="Path to the official RF-DETR repository.") rfdetr_parser.add_argument("--checkpoint", type=Path, required=True, help="Release checkpoint (.safetensors or .pth).") rfdetr_parser.add_argument("--input", type=Path, required=True, help="Input image path.") rfdetr_parser.add_argument("--device", default="cpu", help="Device passed to RF-DETR.") rfdetr_parser.add_argument( "--output-dir", type=Path, default=Path("outputs/rfdetr"), help="Directory where converted weights and outputs are written.", ) rfdetr_parser.add_argument("--threshold", type=float, default=0.4, help="Detection threshold.") rfdetr_parser.add_argument( "--class-names", nargs="+", default=DEFAULT_CLASS_NAMES, help="Class names stored in converted RF-DETR checkpoints.", ) rfdetr_parser.set_defaults(func=run_rfdetr) return parser def main() -> None: parser = build_parser() args = parser.parse_args() args.func(args) if __name__ == "__main__": main()