| | |
| | """Run release checkpoints through the official RT-DETRv4 or RF-DETR repositories.""" |
| |
|
| | from __future__ import annotations |
| |
|
| | import argparse |
| | import json |
| | import subprocess |
| | import sys |
| | from argparse import Namespace |
| | from pathlib import Path |
| |
|
| | import torch |
| | from safetensors.torch import load_file |
| |
|
| |
|
# Lower-case file extensions (with leading dot) accepted as still images.
IMAGE_SUFFIXES = {
    ".jpg",
    ".jpeg",
    ".png",
    ".bmp",
    ".webp",
}

# Lower-case file extensions (with leading dot) treated as video inputs.
VIDEO_SUFFIXES = {
    ".mp4",
    ".avi",
    ".mov",
    ".mkv",
}

# Class names used whenever the caller does not provide its own list.
DEFAULT_CLASS_NAMES = ["person", "head"]
| |
|
| |
|
def convert_checkpoint(
    framework: str,
    checkpoint_path: Path,
    output_dir: Path,
    class_names: list[str] | None = None,
) -> Path:
    """Return a ``.pth`` checkpoint for *checkpoint_path*, converting if needed.

    A ``.pth`` input is returned unchanged (as a resolved path). A
    ``.safetensors`` input is loaded and re-saved with ``torch.save`` as
    ``<output_dir>/<stem>.pth``.

    Args:
        framework: ``"rtdetrv4"`` stores only ``{"model": state_dict}``; any
            other value additionally stores an ``args`` namespace carrying
            the class names.
        checkpoint_path: Source checkpoint (``.pth`` or ``.safetensors``,
            extension matched case-insensitively).
        output_dir: Directory for converted weights; created if missing.
        class_names: Class names stored for non-``rtdetrv4`` frameworks.
            Falls back to ``DEFAULT_CLASS_NAMES`` when ``None`` or empty.

    Returns:
        Path to a ``.pth`` checkpoint.

    Raises:
        ValueError: If the extension is neither ``.pth`` nor ``.safetensors``.
    """
    checkpoint_path = checkpoint_path.resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    # Match extensions case-insensitively, like the input-suffix checks
    # elsewhere in this file, so "model.PTH" is not rejected.
    suffix = checkpoint_path.suffix.lower()

    if suffix == ".pth":
        # Already in the target format: nothing to convert.
        return checkpoint_path

    if suffix != ".safetensors":
        raise ValueError(f"Unsupported checkpoint format: {checkpoint_path}")

    state_dict = load_file(str(checkpoint_path))
    output_path = output_dir / f"{checkpoint_path.stem}.pth"

    if framework == "rtdetrv4":
        payload = {"model": state_dict}
    else:
        # NOTE(review): presumably the RF-DETR loader reads class names from
        # the checkpoint's ``args`` entry - confirm against upstream.
        payload = {
            "model": state_dict,
            "args": Namespace(class_names=class_names or DEFAULT_CLASS_NAMES),
        }

    torch.save(payload, output_path)
    return output_path
| |
|
| |
|
def infer_teacher_dim(checkpoint_path: Path, explicit: int | None) -> int:
    """Determine the distillation teacher dimension for a checkpoint.

    Resolution order:

    1. *explicit*, when it is not ``None``.
    2. The first dimension of the ``encoder.feature_projector.0.weight``
       tensor, when the checkpoint is a ``.safetensors`` or ``.pth`` file
       (extension matched case-insensitively) that contains that key.
    3. A guess from the lower-cased full path: ``1152`` if it contains
       ``"cradio"`` or ``"c-radio"``, otherwise ``768``.

    Args:
        checkpoint_path: Checkpoint to inspect.
        explicit: User-supplied override, or ``None`` to auto-detect.

    Returns:
        The teacher projection dimension.
    """
    if explicit is not None:
        return explicit

    checkpoint_path = checkpoint_path.resolve()
    # Case-insensitive so "model.PTH" is still inspected instead of silently
    # falling through to the filename heuristic.
    suffix = checkpoint_path.suffix.lower()

    state_dict = None
    if suffix == ".safetensors":
        state_dict = load_file(str(checkpoint_path))
    elif suffix == ".pth":
        # SECURITY: weights_only=False unpickles arbitrary objects. Only use
        # this on checkpoints from a trusted source.
        payload = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
        # Accept both a wrapped {"model": state_dict} payload and a bare state dict.
        state_dict = payload["model"] if isinstance(payload, dict) and "model" in payload else payload

    projector_key = "encoder.feature_projector.0.weight"
    if isinstance(state_dict, dict) and projector_key in state_dict:
        return int(state_dict[projector_key].shape[0])

    # Fall back to the path; "cradio" also covers names such as "cradiov4".
    name = checkpoint_path.as_posix().lower()
    if "cradio" in name or "c-radio" in name:
        return 1152
    return 768
| |
|
| |
|
def write_rtdetrv4_config(repo_path: Path, output_dir: Path, teacher_dim: int) -> Path:
    """Write a small YAML config that layers overrides on the upstream base config.

    The generated file includes
    ``<repo_path>/configs/rtv4/rtv4_hgnetv2_s_coco.yml`` and sets
    ``num_classes`` to 2, ``remap_mscoco_category`` to False,
    ``HGNetv2.pretrained`` to False and
    ``HybridEncoder.distill_teacher_dim`` to *teacher_dim*.

    Args:
        repo_path: Root of the RT-DETRv4 repository checkout.
        output_dir: Directory to write the config into; created if missing.
        teacher_dim: Value written as ``distill_teacher_dim``.

    Returns:
        Path of the written ``rtdetrv4_person_head_inference.yml``.
    """
    # Make the function usable on its own, not only after the caller's mkdir.
    output_dir.mkdir(parents=True, exist_ok=True)

    config_path = output_dir / "rtdetrv4_person_head_inference.yml"
    base_config = (repo_path / "configs" / "rtv4" / "rtv4_hgnetv2_s_coco.yml").resolve()

    # The path is emitted as a YAML single-quoted scalar, where a literal
    # apostrophe must be written as two apostrophes; otherwise a path such as
    # "/home/o'brien/repo" produces an unparsable config.
    quoted_base = str(base_config).replace("'", "''")

    config_text = (
        "__include__: [\n"
        f" '{quoted_base}'\n"
        "]\n\n"
        "num_classes: 2\n"
        "remap_mscoco_category: False\n\n"
        "HGNetv2:\n"
        " pretrained: False\n\n"
        "HybridEncoder:\n"
        f" distill_teacher_dim: {teacher_dim}\n"
    )
    config_path.write_text(config_text)
    return config_path
| |
|
| |
|
def run_rtdetrv4(args: argparse.Namespace) -> None:
    """Run the RT-DETRv4 repository's ``tools/inference/torch_inf.py`` on one input.

    Steps: resolve paths, obtain a ``.pth`` checkpoint, work out the teacher
    dimension, write the override config, launch the upstream script with the
    output directory as its working directory, then print a JSON summary.

    Args:
        args: Parsed CLI namespace providing ``repo``, ``input``,
            ``output_dir``, ``checkpoint``, ``teacher_dim`` and ``device``.

    Raises:
        subprocess.CalledProcessError: If the upstream script exits non-zero.
    """
    repo_root = args.repo.resolve()
    source = args.input.resolve()
    out_root = args.output_dir.resolve()
    out_root.mkdir(parents=True, exist_ok=True)

    weights = convert_checkpoint("rtdetrv4", args.checkpoint, out_root / "artifacts")
    teacher_dim = infer_teacher_dim(args.checkpoint, args.teacher_dim)
    generated_config = write_rtdetrv4_config(repo_root, out_root, teacher_dim)

    inference_script = (repo_root / "tools" / "inference" / "torch_inf.py").resolve()
    flag_values = (
        ("-c", str(generated_config)),
        ("-r", str(weights)),
        ("-i", str(source)),
        ("-d", args.device),
    )
    command = [sys.executable, str(inference_script)]
    for flag, value in flag_values:
        command += [flag, value]

    # All paths above are absolute, so changing the working directory is safe.
    subprocess.run(command, cwd=out_root, check=True)

    # The reported result path is constructed here, not checked for existence.
    # NOTE(review): presumably torch_inf.py writes torch_results.* into its
    # working directory - confirm against the upstream script.
    is_video = source.suffix.lower() in VIDEO_SUFFIXES
    result_file = out_root / ("torch_results.mp4" if is_video else "torch_results.jpg")

    summary = {
        "framework": "rtdetrv4",
        "converted_checkpoint": str(weights),
        "generated_config": str(generated_config),
        "result": str(result_file),
        "teacher_dim": teacher_dim,
    }
    print(json.dumps(summary, indent=2))
| |
|
| |
|
def build_label_lookup(class_names) -> dict[int, str]:
    """Normalise *class_names* into a ``{class_id: name}`` mapping.

    A dict has its keys converted with ``int``; a list or tuple is numbered
    from zero; any other value (including ``None``) gives an empty mapping.
    """
    if isinstance(class_names, (list, tuple)):
        return dict(enumerate(class_names))
    if isinstance(class_names, dict):
        return {int(key): label for key, label in class_names.items()}
    return {}
| |
|
| |
|
def resolve_class_name(label_lookup: dict[int, str], raw_class_id: int) -> str:
    """Return the display name for *raw_class_id*.

    The id itself is tried first, then ``raw_class_id + 1``; if neither is a
    key of *label_lookup*, the id is returned as a string.

    NOTE(review): with a lookup keyed 1..N, ids 0 and 1 both resolve to the
    entry for 1. Confirm whether the detector emits 0- or 1-based ids.
    """
    for candidate in (raw_class_id, raw_class_id + 1):
        if candidate in label_lookup:
            return label_lookup[candidate]
    return str(raw_class_id)
| |
|
| |
|
def run_rfdetr(args: argparse.Namespace) -> None:
    """Run RF-DETR on a single image and write annotated and JSON results.

    Writes ``<stem>_rfdetr.jpg`` (boxes and labels drawn on the image) and
    ``<stem>_rfdetr.json`` (one record per detection) into the output
    directory, then prints a JSON summary of the produced paths.

    Args:
        args: Parsed CLI namespace providing ``repo``, ``output_dir``,
            ``input``, ``checkpoint``, ``class_names``, ``device`` and
            ``threshold``.

    Raises:
        ValueError: If the input extension is not in ``IMAGE_SUFFIXES``.
    """
    # Heavy / optional dependencies are imported lazily so the rtdetrv4
    # sub-command does not need them.
    import numpy as np
    import supervision as sv
    from PIL import Image

    repo_path = args.repo.resolve()
    output_dir = args.output_dir.resolve()
    input_path = args.input.resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    if input_path.suffix.lower() not in IMAGE_SUFFIXES:
        raise ValueError("RF-DETR wrapper currently supports image inference only.")

    # Make the checkout importable; "src" is inserted last so it is searched first.
    sys.path.insert(0, str(repo_path))
    sys.path.insert(0, str(repo_path / "src"))

    from rfdetr import RFDETRSmall

    converted_ckpt = convert_checkpoint(
        "rfdetr",
        args.checkpoint,
        output_dir / "artifacts",
        class_names=args.class_names,
    )

    model = RFDETRSmall(
        pretrain_weights=str(converted_ckpt),
        device=args.device,
    )

    # Image.open keeps the file handle open until the image is closed; the
    # context manager releases it once the RGB copy has been made.
    with Image.open(input_path) as source_image:
        image = source_image.convert("RGB")
    detections = model.predict(image, threshold=args.threshold)

    # Prefer the class names exposed by the model, falling back to the CLI value.
    label_lookup = build_label_lookup(getattr(model, "class_names", args.class_names))

    # Build the on-image labels and the JSON records in a single pass so both
    # always use the same resolved class name.
    labels = []
    predictions = []
    for box, confidence, class_id in zip(
        detections.xyxy.tolist(),
        detections.confidence.tolist(),
        detections.class_id.tolist(),
    ):
        raw_id = int(class_id)
        class_name = resolve_class_name(label_lookup, raw_id)
        labels.append(f"{class_name} {confidence:.2f}")
        predictions.append(
            {
                "bbox_xyxy": [round(float(v), 4) for v in box],
                "confidence": round(float(confidence), 6),
                "class_id": raw_id,
                "class_name": class_name,
            }
        )

    image_np = np.array(image)
    annotated = sv.BoxAnnotator().annotate(scene=image_np, detections=detections)
    annotated = sv.LabelAnnotator().annotate(scene=annotated, detections=detections, labels=labels)

    output_image = output_dir / f"{input_path.stem}_rfdetr.jpg"
    output_json = output_dir / f"{input_path.stem}_rfdetr.json"

    Image.fromarray(annotated).save(output_image)
    output_json.write_text(json.dumps(predictions, indent=2))

    print(
        json.dumps(
            {
                "framework": "rfdetr",
                "converted_checkpoint": str(converted_ckpt),
                "result_image": str(output_image),
                "result_json": str(output_json),
            },
            indent=2,
        )
    )
| |
|
| |
|
def build_parser() -> argparse.ArgumentParser:
    """Create the CLI parser with one sub-command per supported framework.

    The chosen sub-command name is stored as ``framework`` and its runner
    function as ``func`` on the parsed namespace.
    """
    root = argparse.ArgumentParser(
        description="Run this release through the official RT-DETRv4 or RF-DETR repositories."
    )
    commands = root.add_subparsers(dest="framework", required=True)

    # --- rtdetrv4 -----------------------------------------------------------
    rt_cmd = commands.add_parser("rtdetrv4", help="Run official RT-DETRv4 inference.")
    rt_required_paths = (
        ("--repo", "Path to the official RT-DETRv4 repository."),
        ("--checkpoint", "Release checkpoint (.safetensors or .pth)."),
        ("--input", "Input image or video path."),
    )
    for flag, help_text in rt_required_paths:
        rt_cmd.add_argument(flag, type=Path, required=True, help=help_text)
    rt_cmd.add_argument("--device", default="cpu", help="Inference device passed to official script.")
    rt_cmd.add_argument(
        "--output-dir",
        type=Path,
        default=Path("outputs/rtdetrv4"),
        help="Directory where converted weights, temp config, and outputs are written.",
    )
    rt_cmd.add_argument(
        "--teacher-dim",
        type=int,
        choices=(768, 1152),
        default=None,
        help="Override the RT-DETRv4 distillation projection dimension if auto-detection is wrong.",
    )
    rt_cmd.set_defaults(func=run_rtdetrv4)

    # --- rfdetr -------------------------------------------------------------
    rf_cmd = commands.add_parser("rfdetr", help="Run official RF-DETR inference.")
    rf_required_paths = (
        ("--repo", "Path to the official RF-DETR repository."),
        ("--checkpoint", "Release checkpoint (.safetensors or .pth)."),
        ("--input", "Input image path."),
    )
    for flag, help_text in rf_required_paths:
        rf_cmd.add_argument(flag, type=Path, required=True, help=help_text)
    rf_cmd.add_argument("--device", default="cpu", help="Device passed to RF-DETR.")
    rf_cmd.add_argument(
        "--output-dir",
        type=Path,
        default=Path("outputs/rfdetr"),
        help="Directory where converted weights and outputs are written.",
    )
    rf_cmd.add_argument("--threshold", type=float, default=0.4, help="Detection threshold.")
    rf_cmd.add_argument(
        "--class-names",
        nargs="+",
        default=DEFAULT_CLASS_NAMES,
        help="Class names stored in converted RF-DETR checkpoints.",
    )
    rf_cmd.set_defaults(func=run_rfdetr)

    return root
| |
|
| |
|
def main(argv: list[str] | None = None) -> None:
    """Parse command-line arguments and dispatch to the selected runner.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the normal command-line behaviour; passing a
            list lets other code invoke the CLI programmatically.
    """
    args = build_parser().parse_args(argv)
    # Each sub-parser registers its runner as ``func`` via set_defaults.
    args.func(args)


if __name__ == "__main__":
    main()
| |
|