"""Run inference with a trained RF-DETR checkpoint on aerial images."""

import argparse
from pathlib import Path

import cv2
import rfdetr
import supervision as sv

_DIR = Path(__file__).resolve().parent

# Maps the ``--model-size`` choice to the RF-DETR model class to instantiate.
MODEL_CLASSES: dict[str, type] = {
    "nano": rfdetr.RFDETRNano,
    "small": rfdetr.RFDETRSmall,
    "base": rfdetr.RFDETRBase,
    "medium": rfdetr.RFDETRMedium,
    "large": rfdetr.RFDETRLarge,
}

# Display names for the class ids emitted by the trained model.
# NOTE(review): assumes the checkpoint was trained with exactly these ids —
# confirm against the training dataset's category ids.
prediction_classes = {0: "empty_spot", 1: "parked_car"}


def _format_labels(detections: sv.Detections) -> list[str]:
    """Build one ``"<class name> <confidence>"`` label per detection.

    Class ids missing from ``prediction_classes`` are shown as ``class_<id>``
    instead of raising ``KeyError``. If the detections carry no ``class_id``
    or ``confidence`` array, the corresponding part of the label is omitted,
    so the result always has exactly ``len(detections)`` entries.
    """
    count = len(detections)
    class_ids = detections.class_id if detections.class_id is not None else [None] * count
    confidences = (
        detections.confidence if detections.confidence is not None else [None] * count
    )

    labels: list[str] = []
    for class_id, conf in zip(class_ids, confidences):
        if class_id is None:
            name = "object"
        else:
            # int() normalises numpy integer ids before the dict lookup.
            name = prediction_classes.get(int(class_id), f"class_{int(class_id)}")
        labels.append(name if conf is None else f"{name} {conf:.2f}")
    return labels


def run_inference(
    image_paths: list[Path],
    checkpoint: str | Path,
    model_size: str = "medium",
    threshold: float = 0.5,
    # NOTE(review): differs from the CLI default (_DIR / "inference_output");
    # kept as-is so existing callers are unaffected.
    output_dir: str | Path = "./inference_output2",
) -> None:
    """Load an RF-DETR checkpoint and run detection on input images.

    Each annotated image is written to ``output_dir`` under its original file
    name. Inputs that share a file name overwrite each other's output.

    Args:
        image_paths: Paths to input images (``str`` paths are also accepted).
        checkpoint: Path to trained checkpoint file.
        model_size: Model variant (must match training size).
        threshold: Confidence threshold for detections.
        output_dir: Directory to save annotated images (created if missing).

    Raises:
        ValueError: If ``model_size`` is unknown or an image cannot be decoded.
        OSError: If an annotated image cannot be written.
    """
    model_cls = MODEL_CLASSES.get(model_size)
    if model_cls is None:
        raise ValueError(
            f"Unknown model_size {model_size!r}, "
            f"choose from: {', '.join(MODEL_CLASSES)}"
        )
    model = model_cls(pretrain_weights=str(checkpoint))

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    box_annotator = sv.BoxAnnotator()
    label_annotator = sv.LabelAnnotator()

    for image_path in map(Path, image_paths):
        # Decode before inference so an unreadable file fails fast and clearly
        # (cv2.imread returns None rather than raising).
        image = cv2.imread(str(image_path))
        if image is None:
            raise ValueError(f"Could not read image: {image_path}")

        detections: sv.Detections = model.predict(str(image_path), threshold=threshold)
        labels = _format_labels(detections)

        # The decoded image is not reused, so it can be annotated in place.
        annotated = box_annotator.annotate(scene=image, detections=detections)
        annotated = label_annotator.annotate(
            scene=annotated, detections=detections, labels=labels
        )

        out_file = output_path / image_path.name
        # cv2.imwrite signals failure (e.g. unsupported extension) by returning False.
        if not cv2.imwrite(str(out_file), annotated):
            raise OSError(f"Failed to write annotated image: {out_file}")
        print(f"{image_path.name}: {len(detections)} detections -> {out_file}")


def main() -> None:
    """Parse command-line arguments, validate the inputs and run inference."""
    parser = argparse.ArgumentParser(
        description="Run RF-DETR inference on aerial images"
    )
    parser.add_argument(
        "images",
        nargs="+",
        help="Input image path(s)",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=str(_DIR / "output" / "checkpoint_best_regular.pth"),
        help="Path to trained checkpoint",
    )
    parser.add_argument(
        "--model-size",
        type=str,
        default="medium",
        choices=list(MODEL_CLASSES),
        help="Model size, must match training (default: medium)",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=0.5,
        help="Confidence threshold (default: 0.5)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=str(_DIR / "inference_output"),
        help="Output directory for annotated images",
    )
    args = parser.parse_args()

    image_paths = [Path(p) for p in args.images]
    # is_file() (not exists()) so that directories are rejected up front too.
    missing = [p for p in image_paths if not p.is_file()]
    if missing:
        parser.error(
            f"Image file(s) not found: {', '.join(str(p) for p in missing)}"
        )

    run_inference(
        image_paths=image_paths,
        checkpoint=args.checkpoint,
        model_size=args.model_size,
        threshold=args.threshold,
        output_dir=args.output_dir,
    )


if __name__ == "__main__":
    main()