| |
| """Run release checkpoints through the official RT-DETRv4 or RF-DETR repositories.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import subprocess |
| import sys |
| from argparse import Namespace |
| from pathlib import Path |
|
|
| import torch |
| from safetensors.torch import load_file |
|
|
|
|
# Lower-case file suffixes the runners compare against ``Path.suffix.lower()``
# to decide between image and video handling. frozenset keeps these shared
# module-level lookup tables immutable.
IMAGE_SUFFIXES = frozenset({".jpg", ".jpeg", ".png", ".bmp", ".webp"})
VIDEO_SUFFIXES = frozenset({".mp4", ".avi", ".mov", ".mkv"})
# Default for the RF-DETR ``--class-names`` option and the fallback class names
# written into converted RF-DETR checkpoints.
DEFAULT_CLASS_NAMES = ["person", "head"]
|
|
|
|
def convert_checkpoint(
    framework: str,
    checkpoint_path: Path,
    output_dir: Path,
    class_names: list[str] | None = None,
) -> Path:
    """Return a ``.pth`` checkpoint path that the official repositories can load.

    ``.pth`` inputs are returned unchanged (resolved). ``.safetensors`` inputs
    are loaded and re-saved with ``torch.save`` as ``<output_dir>/<stem>.pth``:

    * ``framework == "rtdetrv4"``: payload is ``{"model": state_dict}``.
    * any other framework: the payload additionally carries ``"args"``, an
      ``argparse.Namespace`` whose ``class_names`` is a copy of
      ``class_names`` (or of ``DEFAULT_CLASS_NAMES`` when none are given).

    Args:
        framework: Target framework name; only ``"rtdetrv4"`` is special-cased.
        checkpoint_path: Release checkpoint. The suffix is matched
            case-insensitively against ``.pth`` / ``.safetensors``.
        output_dir: Destination directory. It is created only when a converted
            file is actually written.
        class_names: Optional class names for non-RT-DETRv4 payloads.

    Returns:
        Path to the ``.pth`` file to hand to the framework.

    Raises:
        ValueError: If the suffix is neither ``.pth`` nor ``.safetensors``.
    """
    checkpoint_path = checkpoint_path.resolve()
    suffix = checkpoint_path.suffix.lower()

    if suffix == ".pth":
        return checkpoint_path
    if suffix != ".safetensors":
        raise ValueError(f"Unsupported checkpoint format: {checkpoint_path}")

    state_dict = load_file(str(checkpoint_path))
    # Create the directory only now, so failed or passthrough calls leave no trace.
    output_dir.mkdir(parents=True, exist_ok=True)
    output_path = output_dir / f"{checkpoint_path.stem}.pth"

    payload = {"model": state_dict}
    if framework != "rtdetrv4":
        # Copy so the saved Namespace never aliases the caller's list or the
        # shared module-level default.
        payload["args"] = Namespace(class_names=list(class_names or DEFAULT_CLASS_NAMES))

    torch.save(payload, output_path)
    return output_path
|
|
|
|
def infer_teacher_dim(checkpoint_path: Path, explicit: int | None) -> int:
    """Return the ``distill_teacher_dim`` to write into the RT-DETRv4 config.

    Resolution order:

    1. ``explicit`` whenever it is not ``None``.
    2. ``shape[0]`` of ``encoder.feature_projector.0.weight`` read from the
       checkpoint. ``.safetensors`` and ``.pth`` are both supported, with the
       suffix matched case-insensitively. For ``.pth`` files a top-level
       ``"model"`` entry is used when present.
    3. A path heuristic: ``1152`` if the lower-cased resolved path contains
       ``"cradio"`` or ``"c-radio"``, otherwise ``768``.

    Args:
        checkpoint_path: Release checkpoint to inspect.
        explicit: User-supplied override (``--teacher-dim``), or ``None``.

    Returns:
        The teacher projection dimension.
    """
    if explicit is not None:
        return explicit

    checkpoint_path = checkpoint_path.resolve()
    suffix = checkpoint_path.suffix.lower()
    state_dict = None
    if suffix == ".safetensors":
        state_dict = load_file(str(checkpoint_path))
    elif suffix == ".pth":
        # NOTE(security): weights_only=False unpickles arbitrary Python objects.
        # Only point this at trusted checkpoints.
        payload = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
        state_dict = payload["model"] if isinstance(payload, dict) and "model" in payload else payload

    projector_key = "encoder.feature_projector.0.weight"
    if isinstance(state_dict, dict) and projector_key in state_dict:
        return int(state_dict[projector_key].shape[0])

    # Fallback: guess from the full path. "cradio" also covers "cradiov4".
    name = checkpoint_path.as_posix().lower()
    if any(tag in name for tag in ("cradio", "c-radio")):
        return 1152
    return 768
|
|
|
|
def write_rtdetrv4_config(
    repo_path: Path,
    output_dir: Path,
    teacher_dim: int,
    num_classes: int = 2,
) -> Path:
    """Write an inference YAML that extends the official RT-DETRv4-S COCO config.

    The generated file includes ``configs/rtv4/rtv4_hgnetv2_s_coco.yml`` from
    ``repo_path`` and overrides the following keys:

    * ``num_classes``
    * ``remap_mscoco_category`` (set to ``False``)
    * ``HGNetv2.pretrained`` (set to ``False``)
    * ``HybridEncoder.distill_teacher_dim``

    Args:
        repo_path: Root of the official RT-DETRv4 checkout.
        output_dir: Existing directory that receives the config file.
        teacher_dim: Value for ``HybridEncoder.distill_teacher_dim``.
        num_classes: Value for ``num_classes``. Defaults to 2, matching the
            person/head release.

    Returns:
        Path of the written ``rtdetrv4_person_head_inference.yml``.
    """
    config_path = output_dir / "rtdetrv4_person_head_inference.yml"
    base_config = (repo_path / "configs" / "rtv4" / "rtv4_hgnetv2_s_coco.yml").resolve()
    # In a YAML single-quoted scalar a literal quote is written as two quotes.
    # Without this, a path containing ' would produce invalid YAML.
    quoted_base = "'" + str(base_config).replace("'", "''") + "'"
    config_text = (
        "__include__: [\n"
        f" {quoted_base}\n"
        "]\n\n"
        f"num_classes: {num_classes}\n"
        "remap_mscoco_category: False\n\n"
        "HGNetv2:\n"
        " pretrained: False\n\n"
        "HybridEncoder:\n"
        f" distill_teacher_dim: {teacher_dim}\n"
    )
    # Explicit encoding: repository paths may contain non-ASCII characters.
    config_path.write_text(config_text, encoding="utf-8")
    return config_path
|
|
|
|
def run_rtdetrv4(args: argparse.Namespace) -> None:
    """Run the official RT-DETRv4 ``tools/inference/torch_inf.py`` on ``args.input``.

    Steps:

    1. Convert the checkpoint into ``<output-dir>/artifacts`` if needed.
    2. Work out the teacher dimension and write a config that overrides the
       base RT-DETRv4-S config.
    3. Launch the official script as a subprocess with ``<output-dir>`` as its
       working directory.
    4. Print a JSON summary of the produced paths.

    Raises:
        subprocess.CalledProcessError: If the official script exits non-zero.
    """
    repo = args.repo.resolve()
    source = args.input.resolve()
    workdir = args.output_dir.resolve()
    workdir.mkdir(parents=True, exist_ok=True)

    checkpoint = convert_checkpoint("rtdetrv4", args.checkpoint, workdir / "artifacts")
    teacher_dim = infer_teacher_dim(args.checkpoint, args.teacher_dim)
    config = write_rtdetrv4_config(repo, workdir, teacher_dim)

    script = (repo / "tools" / "inference" / "torch_inf.py").resolve()
    subprocess.run(
        [
            sys.executable, str(script),
            "-c", str(config),
            "-r", str(checkpoint),
            "-i", str(source),
            "-d", args.device,
        ],
        cwd=workdir,
        check=True,
    )

    # The reported result path assumes the script writes its output into its
    # working directory, under a name that depends on whether the input is a video.
    is_video = source.suffix.lower() in VIDEO_SUFFIXES
    result = workdir / ("torch_results.mp4" if is_video else "torch_results.jpg")

    summary = {
        "framework": "rtdetrv4",
        "converted_checkpoint": str(checkpoint),
        "generated_config": str(config),
        "result": str(result),
        "teacher_dim": teacher_dim,
    }
    print(json.dumps(summary, indent=2))
|
|
|
|
def build_label_lookup(class_names) -> dict[int, str]:
    """Normalise ``class_names`` into an ``{int id: name}`` mapping.

    * ``list`` / ``tuple``: ids are positions, starting at 0.
    * ``dict``: keys are converted with ``int()`` and values are kept.
    * anything else (e.g. ``None``): an empty mapping.
    """
    if isinstance(class_names, (list, tuple)):
        return dict(enumerate(class_names))
    if not isinstance(class_names, dict):
        return {}
    return {int(key): value for key, value in class_names.items()}
|
|
|
|
def resolve_class_name(label_lookup: dict[int, str], raw_class_id: int) -> str:
    """Map a predicted class id to a display name.

    Lookup order:

    1. ``raw_class_id`` itself.
    2. ``raw_class_id + 1``, to tolerate a 1-indexed lookup.
    3. Otherwise, the id rendered as a string.

    NOTE(review): with a 1-indexed lookup such as ``{1: "person", 2: "head"}``,
    ids 0 and 1 both resolve to ``"person"``. Confirm which indexing the model
    actually emits.
    """
    for candidate in (raw_class_id, raw_class_id + 1):
        if candidate in label_lookup:
            return label_lookup[candidate]
    return str(raw_class_id)
|
|
|
|
def _prepend_sys_path(entry: Path) -> None:
    """Move ``entry`` to the front of ``sys.path``, removing any earlier copies."""
    text = str(entry)
    while text in sys.path:
        sys.path.remove(text)
    sys.path.insert(0, text)


def run_rfdetr(args: argparse.Namespace) -> None:
    """Run single-image inference through the official RF-DETR package.

    The RF-DETR checkout (``<repo>/src`` first, then ``<repo>``) is put at the
    front of ``sys.path`` before ``rfdetr`` is imported. The checkpoint is
    converted into ``<output-dir>/artifacts`` if needed and loaded into
    ``RFDETRSmall``.

    The function writes two files:

    * ``<stem>_rfdetr.jpg``: the image annotated with boxes and labels.
    * ``<stem>_rfdetr.json``: the per-detection box, confidence, class id and
      class name.

    A JSON summary of the produced paths is then printed to stdout.

    Raises:
        ValueError: If the input suffix is not in ``IMAGE_SUFFIXES``.
    """
    import numpy as np
    import supervision as sv
    from PIL import Image

    repo_path = args.repo.resolve()
    output_dir = args.output_dir.resolve()
    input_path = args.input.resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    if input_path.suffix.lower() not in IMAGE_SUFFIXES:
        raise ValueError("RF-DETR wrapper currently supports image inference only.")

    # Final order is <repo>/src, then <repo>, and repeated calls don't pile up duplicates.
    _prepend_sys_path(repo_path)
    _prepend_sys_path(repo_path / "src")

    from rfdetr import RFDETRSmall

    converted_ckpt = convert_checkpoint(
        "rfdetr",
        args.checkpoint,
        output_dir / "artifacts",
        class_names=args.class_names,
    )
    model = RFDETRSmall(
        pretrain_weights=str(converted_ckpt),
        device=args.device,
    )

    # Context manager closes the source file handle; convert() returns an
    # independent image, so it stays usable afterwards.
    with Image.open(input_path) as source_image:
        image = source_image.convert("RGB")
    detections = model.predict(image, threshold=args.threshold)

    label_lookup = build_label_lookup(getattr(model, "class_names", args.class_names))
    boxes = detections.xyxy.tolist()
    confidences = detections.confidence.tolist()
    class_ids = [int(class_id) for class_id in detections.class_id.tolist()]
    # Resolve each name once; reused for the drawn labels and the JSON output.
    class_names = [resolve_class_name(label_lookup, class_id) for class_id in class_ids]

    labels = [f"{name} {confidence:.2f}" for name, confidence in zip(class_names, confidences)]
    annotated = sv.BoxAnnotator().annotate(scene=np.array(image), detections=detections)
    annotated = sv.LabelAnnotator().annotate(scene=annotated, detections=detections, labels=labels)

    output_image = output_dir / f"{input_path.stem}_rfdetr.jpg"
    output_json = output_dir / f"{input_path.stem}_rfdetr.json"
    Image.fromarray(annotated).save(output_image)

    predictions = [
        {
            "bbox_xyxy": [round(float(v), 4) for v in box],
            "confidence": round(float(confidence), 6),
            "class_id": class_id,
            "class_name": name,
        }
        for box, confidence, class_id, name in zip(boxes, confidences, class_ids, class_names)
    ]
    output_json.write_text(json.dumps(predictions, indent=2))

    print(
        json.dumps(
            {
                "framework": "rfdetr",
                "converted_checkpoint": str(converted_ckpt),
                "result_image": str(output_image),
                "result_json": str(output_json),
            },
            indent=2,
        )
    )
|
|
|
|
def _add_shared_arguments(
    subparser: argparse.ArgumentParser,
    *,
    repo_help: str,
    input_help: str,
    device_help: str,
    default_output: Path,
    output_help: str,
) -> None:
    """Register the options both framework subcommands share, in a fixed order."""
    subparser.add_argument("--repo", type=Path, required=True, help=repo_help)
    subparser.add_argument(
        "--checkpoint", type=Path, required=True, help="Release checkpoint (.safetensors or .pth)."
    )
    subparser.add_argument("--input", type=Path, required=True, help=input_help)
    subparser.add_argument("--device", default="cpu", help=device_help)
    subparser.add_argument("--output-dir", type=Path, default=default_output, help=output_help)


def build_parser() -> argparse.ArgumentParser:
    """Build the CLI with one subcommand per supported framework.

    Each subcommand stores its runner (``run_rtdetrv4`` / ``run_rfdetr``) as
    ``func`` on the parsed namespace.
    """
    parser = argparse.ArgumentParser(
        description="Run this release through the official RT-DETRv4 or RF-DETR repositories."
    )
    commands = parser.add_subparsers(dest="framework", required=True)

    rtdetr = commands.add_parser("rtdetrv4", help="Run official RT-DETRv4 inference.")
    _add_shared_arguments(
        rtdetr,
        repo_help="Path to the official RT-DETRv4 repository.",
        input_help="Input image or video path.",
        device_help="Inference device passed to official script.",
        default_output=Path("outputs/rtdetrv4"),
        output_help="Directory where converted weights, temp config, and outputs are written.",
    )
    rtdetr.add_argument(
        "--teacher-dim",
        type=int,
        choices=(768, 1152),
        default=None,
        help="Override the RT-DETRv4 distillation projection dimension if auto-detection is wrong.",
    )
    rtdetr.set_defaults(func=run_rtdetrv4)

    rfdetr = commands.add_parser("rfdetr", help="Run official RF-DETR inference.")
    _add_shared_arguments(
        rfdetr,
        repo_help="Path to the official RF-DETR repository.",
        input_help="Input image path.",
        device_help="Device passed to RF-DETR.",
        default_output=Path("outputs/rfdetr"),
        output_help="Directory where converted weights and outputs are written.",
    )
    rfdetr.add_argument("--threshold", type=float, default=0.4, help="Detection threshold.")
    rfdetr.add_argument(
        "--class-names",
        nargs="+",
        default=DEFAULT_CLASS_NAMES,
        help="Class names stored in converted RF-DETR checkpoints.",
    )
    rfdetr.set_defaults(func=run_rfdetr)

    return parser
|
|
|
|
def main() -> None:
    """Parse the command line and hand off to the chosen framework runner."""
    namespace = build_parser().parse_args()
    # Each subcommand registers its runner through set_defaults(func=...).
    namespace.func(namespace)
|
|
|
|
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|