NikhilSandy's picture
Initial model release
c9ea96d verified
#!/usr/bin/env python3
"""Run release checkpoints through the official RT-DETRv4 or RF-DETR repositories."""
from __future__ import annotations
import argparse
import json
import subprocess
import sys
from argparse import Namespace
from pathlib import Path
import torch
from safetensors.torch import load_file
# Lower-case file extensions (with leading dot) treated as still images; callers
# compare against ``Path.suffix.lower()``.
IMAGE_SUFFIXES = {".bmp", ".jpeg", ".jpg", ".png", ".webp"}
# Lower-case file extensions the RT-DETRv4 runner treats as video inputs when
# predicting the name of the result file.
VIDEO_SUFFIXES = {".avi", ".mkv", ".mov", ".mp4"}
# Fallback class names, in class-index order (index 0 first).
DEFAULT_CLASS_NAMES = ["person", "head"]
def convert_checkpoint(
framework: str,
checkpoint_path: Path,
output_dir: Path,
class_names: list[str] | None = None,
) -> Path:
checkpoint_path = checkpoint_path.resolve()
output_dir.mkdir(parents=True, exist_ok=True)
if checkpoint_path.suffix == ".pth":
return checkpoint_path
if checkpoint_path.suffix != ".safetensors":
raise ValueError(f"Unsupported checkpoint format: {checkpoint_path}")
state_dict = load_file(str(checkpoint_path))
output_path = output_dir / f"{checkpoint_path.stem}.pth"
if framework == "rtdetrv4":
payload = {"model": state_dict}
else:
payload = {
"model": state_dict,
"args": Namespace(class_names=class_names or DEFAULT_CLASS_NAMES),
}
torch.save(payload, output_path)
return output_path
def infer_teacher_dim(checkpoint_path: Path, explicit: int | None) -> int:
if explicit is not None:
return explicit
checkpoint_path = checkpoint_path.resolve()
state_dict = None
if checkpoint_path.suffix == ".safetensors":
state_dict = load_file(str(checkpoint_path))
elif checkpoint_path.suffix == ".pth":
payload = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
state_dict = payload["model"] if isinstance(payload, dict) and "model" in payload else payload
if isinstance(state_dict, dict) and "encoder.feature_projector.0.weight" in state_dict:
return int(state_dict["encoder.feature_projector.0.weight"].shape[0])
name = checkpoint_path.as_posix().lower()
if "cradio" in name or "cradiov4" in name or "c-radio" in name:
return 1152
return 768
def write_rtdetrv4_config(
    repo_path: Path,
    output_dir: Path,
    teacher_dim: int,
    num_classes: int = 2,
) -> Path:
    """Write a small RT-DETRv4 inference config that includes the S/COCO base.

    The YAML includes ``configs/rtv4/rtv4_hgnetv2_s_coco.yml`` from
    ``repo_path`` by absolute path. It then sets:

    * ``num_classes``
    * ``remap_mscoco_category: False``
    * ``HGNetv2.pretrained: False``
    * ``HybridEncoder.distill_teacher_dim``

    Args:
        repo_path: Root of the official RT-DETRv4 checkout.
        output_dir: An existing directory to write the config into.
        teacher_dim: Value written for ``distill_teacher_dim``.
        num_classes: Number of detection classes. Defaults to 2, matching the
            two default class names.

    Returns:
        Path of the written ``rtdetrv4_person_head_inference.yml``.
    """
    config_path = output_dir / "rtdetrv4_person_head_inference.yml"
    base_config = (repo_path / "configs" / "rtv4" / "rtv4_hgnetv2_s_coco.yml").resolve()
    # In YAML single-quoted scalars an embedded quote is escaped by doubling it;
    # without this a repo path containing ' yields an unparsable config.
    quoted_base = str(base_config).replace("'", "''")
    config_text = (
        "__include__: [\n"
        f" '{quoted_base}'\n"
        "]\n\n"
        f"num_classes: {num_classes}\n"
        "remap_mscoco_category: False\n\n"
        "HGNetv2:\n"
        " pretrained: False\n\n"
        "HybridEncoder:\n"
        f" distill_teacher_dim: {teacher_dim}\n"
    )
    config_path.write_text(config_text)
    return config_path
def run_rtdetrv4(args: argparse.Namespace) -> None:
    """Convert the checkpoint, generate a config, and invoke RT-DETRv4's ``torch_inf.py``.

    The official script runs as a subprocess whose working directory is the
    resolved ``args.output_dir``, so anything it writes relative to its cwd
    lands there. A JSON summary of the produced artifacts goes to stdout.

    Raises:
        subprocess.CalledProcessError: If the official script exits non-zero.
    """
    repo_root = args.repo.resolve()
    source = args.input.resolve()
    workdir = args.output_dir.resolve()
    workdir.mkdir(parents=True, exist_ok=True)

    checkpoint = convert_checkpoint("rtdetrv4", args.checkpoint, workdir / "artifacts")
    dim = infer_teacher_dim(args.checkpoint, args.teacher_dim)
    config = write_rtdetrv4_config(repo_root, workdir, dim)

    script = (repo_root / "tools" / "inference" / "torch_inf.py").resolve()
    cmd = [sys.executable, str(script)]
    cmd += ["-c", str(config)]
    cmd += ["-r", str(checkpoint)]
    cmd += ["-i", str(source)]
    cmd += ["-d", args.device]
    subprocess.run(cmd, cwd=workdir, check=True)

    # NOTE(review): assumes the official script writes torch_results.{jpg,mp4}
    # into its cwd, choosing .mp4 for video inputs; confirm against torch_inf.py.
    if source.suffix.lower() in VIDEO_SUFFIXES:
        result_file = workdir / "torch_results.mp4"
    else:
        result_file = workdir / "torch_results.jpg"

    summary = {
        "framework": "rtdetrv4",
        "converted_checkpoint": str(checkpoint),
        "generated_config": str(config),
        "result": str(result_file),
        "teacher_dim": dim,
    }
    print(json.dumps(summary, indent=2))
def build_label_lookup(class_names) -> dict[int, str]:
    """Normalise ``class_names`` into an ``{int class id: name}`` dict.

    * list or tuple: positions become 0-based ids.
    * dict: keys are coerced with ``int()``; values are kept as-is.
    * anything else (including ``None`` and plain strings): an empty dict.
    """
    if isinstance(class_names, (list, tuple)):
        return dict(enumerate(class_names))
    if not isinstance(class_names, dict):
        return {}
    return {int(key): value for key, value in class_names.items()}
def resolve_class_name(label_lookup: dict[int, str], raw_class_id: int) -> str:
    """Map a predicted class id to a display name.

    The id itself is tried first, then ``raw_class_id + 1`` (a fallback for a
    lookup whose ids start at 1). If neither is found, the id is returned as a
    string.

    NOTE(review): the fallback is decided per id. With a 1-based lookup, ids 0
    and 1 therefore resolve to the same name. Confirm which id convention the
    model's predictions use.
    """
    for candidate in (raw_class_id, raw_class_id + 1):
        if candidate in label_lookup:
            return label_lookup[candidate]
    return str(raw_class_id)
def run_rfdetr(args: argparse.Namespace) -> None:
    """Run single-image inference through the official RF-DETR package.

    Steps:

    1. Reject non-image inputs.
    2. Put the RF-DETR checkout at the front of ``sys.path``.
    3. Convert the checkpoint, embedding ``args.class_names``.
    4. Build an ``RFDETRSmall`` model from the converted checkpoint and predict.
    5. Write two files into ``args.output_dir``: ``<stem>_rfdetr.jpg`` (the
       annotated image) and ``<stem>_rfdetr.json`` (one record per detection).

    A JSON summary is printed to stdout.

    Raises:
        ValueError: If the input suffix is not in ``IMAGE_SUFFIXES``.
    """
    import numpy as np
    import supervision as sv
    from PIL import Image

    repo_path = args.repo.resolve()
    output_dir = args.output_dir.resolve()
    input_path = args.input.resolve()
    output_dir.mkdir(parents=True, exist_ok=True)
    if input_path.suffix.lower() not in IMAGE_SUFFIXES:
        raise ValueError("RF-DETR wrapper currently supports image inference only.")

    # Inserted at the front so the checkout is searched before installed packages;
    # both the repo root and its src/ directory are made importable.
    sys.path.insert(0, str(repo_path))
    sys.path.insert(0, str(repo_path / "src"))
    from rfdetr import RFDETRSmall

    converted_ckpt = convert_checkpoint(
        "rfdetr",
        args.checkpoint,
        output_dir / "artifacts",
        class_names=args.class_names,
    )
    model = RFDETRSmall(
        pretrain_weights=str(converted_ckpt),
        device=args.device,
    )

    # Close the source file right away: convert() returns an independent,
    # fully loaded copy, so the handle is no longer needed.
    with Image.open(input_path) as source_image:
        image = source_image.convert("RGB")
    detections = model.predict(image, threshold=args.threshold)

    # Resolve ids and names once; both the drawn labels and the JSON reuse them.
    label_lookup = build_label_lookup(getattr(model, "class_names", args.class_names))
    boxes = detections.xyxy.tolist()
    confidences = detections.confidence.tolist()
    class_ids = [int(class_id) for class_id in detections.class_id.tolist()]
    class_names = [resolve_class_name(label_lookup, class_id) for class_id in class_ids]
    labels = [f"{name} {confidence:.2f}" for name, confidence in zip(class_names, confidences)]

    image_np = np.array(image)
    annotated = sv.BoxAnnotator().annotate(scene=image_np, detections=detections)
    annotated = sv.LabelAnnotator().annotate(scene=annotated, detections=detections, labels=labels)
    output_image = output_dir / f"{input_path.stem}_rfdetr.jpg"
    output_json = output_dir / f"{input_path.stem}_rfdetr.json"
    Image.fromarray(annotated).save(output_image)

    predictions = [
        {
            "bbox_xyxy": [round(float(v), 4) for v in box],
            "confidence": round(float(confidence), 6),
            "class_id": class_id,
            "class_name": name,
        }
        for box, confidence, class_id, name in zip(boxes, confidences, class_ids, class_names)
    ]
    output_json.write_text(json.dumps(predictions, indent=2))
    print(
        json.dumps(
            {
                "framework": "rfdetr",
                "converted_checkpoint": str(converted_ckpt),
                "result_image": str(output_image),
                "result_json": str(output_json),
            },
            indent=2,
        )
    )
def build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser.

    There is one required sub-command per framework (``rtdetrv4`` and
    ``rfdetr``). ``set_defaults(func=...)`` binds each one to its runner, and
    the chosen name is stored in ``framework``.
    """
    parser = argparse.ArgumentParser(
        description="Run this release through the official RT-DETRv4 or RF-DETR repositories."
    )
    subparsers = parser.add_subparsers(dest="framework", required=True)

    def add_shared_options(sub, *, repo_help, input_help, device_help, output_default, output_help):
        """Register the options both sub-commands accept, in a fixed order."""
        sub.add_argument("--repo", type=Path, required=True, help=repo_help)
        sub.add_argument("--checkpoint", type=Path, required=True, help="Release checkpoint (.safetensors or .pth).")
        sub.add_argument("--input", type=Path, required=True, help=input_help)
        sub.add_argument("--device", default="cpu", help=device_help)
        sub.add_argument("--output-dir", type=Path, default=output_default, help=output_help)

    rtdetr = subparsers.add_parser("rtdetrv4", help="Run official RT-DETRv4 inference.")
    add_shared_options(
        rtdetr,
        repo_help="Path to the official RT-DETRv4 repository.",
        input_help="Input image or video path.",
        device_help="Inference device passed to official script.",
        output_default=Path("outputs/rtdetrv4"),
        output_help="Directory where converted weights, temp config, and outputs are written.",
    )
    rtdetr.add_argument(
        "--teacher-dim",
        type=int,
        choices=(768, 1152),
        default=None,
        help="Override the RT-DETRv4 distillation projection dimension if auto-detection is wrong.",
    )
    rtdetr.set_defaults(func=run_rtdetrv4)

    rfdetr = subparsers.add_parser("rfdetr", help="Run official RF-DETR inference.")
    add_shared_options(
        rfdetr,
        repo_help="Path to the official RF-DETR repository.",
        input_help="Input image path.",
        device_help="Device passed to RF-DETR.",
        output_default=Path("outputs/rfdetr"),
        output_help="Directory where converted weights and outputs are written.",
    )
    rfdetr.add_argument("--threshold", type=float, default=0.4, help="Detection threshold.")
    rfdetr.add_argument(
        "--class-names",
        nargs="+",
        default=DEFAULT_CLASS_NAMES,
        help="Class names stored in converted RF-DETR checkpoints.",
    )
    rfdetr.set_defaults(func=run_rfdetr)
    return parser
def main() -> None:
    """Parse the command line and dispatch to the runner bound to the chosen sub-command."""
    namespace = build_parser().parse_args()
    runner = namespace.func
    runner(namespace)


if __name__ == "__main__":
    main()