Spaces:
Sleeping
Sleeping
| """Run inference with a trained RF-DETR checkpoint on aerial images.""" | |
| import argparse | |
| from pathlib import Path | |
| import cv2 | |
| import rfdetr | |
| import supervision as sv | |
# Directory containing this script; the CLI's default checkpoint and output
# paths are resolved relative to it.
_DIR = Path(__file__).resolve().parent

# Lookup from the ``--model-size`` choice to the RF-DETR class to instantiate.
# The chosen size has to match the architecture the checkpoint was trained with.
MODEL_CLASSES: dict[str, type] = dict(
    nano=rfdetr.RFDETRNano,
    small=rfdetr.RFDETRSmall,
    base=rfdetr.RFDETRBase,
    medium=rfdetr.RFDETRMedium,
    large=rfdetr.RFDETRLarge,
)

# Display names for the integer class ids the trained checkpoint emits
# (0 -> "empty_spot", 1 -> "parked_car").
prediction_classes = dict(enumerate(("empty_spot", "parked_car")))
def _format_labels(class_ids, confidences, class_names: dict[int, str]) -> list[str]:
    """Build one "<class name> <confidence>" label per detection.

    ``class_ids`` and ``confidences`` are parallel per-detection sequences.
    Ids missing from ``class_names`` are shown as the raw id instead of
    raising, so one unexpected class does not abort a whole batch.
    NOTE(review): an unmapped id usually means the label map does not match
    the training dataset; confirm ``class_names`` against it.
    """
    labels = []
    for class_id, conf in zip(class_ids, confidences):
        cid = int(class_id)
        labels.append(f"{class_names.get(cid, str(cid))} {conf:.2f}")
    return labels


def run_inference(
    image_paths: list[Path],
    checkpoint: str | Path,
    model_size: str = "medium",
    threshold: float = 0.5,
    output_dir: str | Path = "./inference_output2",
) -> None:
    """Load an RF-DETR checkpoint and run detection on input images.

    Each image is annotated with bounding boxes and "<class> <confidence>"
    labels, then written to ``output_dir`` under its original file name.
    Inputs that share a file name will overwrite each other's output.

    Args:
        image_paths: Paths (``Path`` or ``str``) to input images.
        checkpoint: Path to trained checkpoint file.
        model_size: Key of ``MODEL_CLASSES``; must match the training size.
        threshold: Confidence threshold for detections.
        output_dir: Directory to save annotated images; created if missing.
            NOTE(review): this default (relative to the current working
            directory) differs from the CLI default in ``main``. Confirm
            which one is intended.

    Raises:
        ValueError: If ``model_size`` is unknown, or an input image cannot
            be decoded by OpenCV.
        OSError: If an annotated image cannot be written.
    """
    model_cls = MODEL_CLASSES.get(model_size)
    if model_cls is None:
        raise ValueError(
            f"Unknown model_size {model_size!r}, "
            f"choose from: {', '.join(MODEL_CLASSES)}"
        )
    model = model_cls(pretrain_weights=str(checkpoint))

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    box_annotator = sv.BoxAnnotator()
    label_annotator = sv.LabelAnnotator()

    for raw_path in image_paths:
        image_path = Path(raw_path)

        # cv2.imread signals failure by returning None rather than raising.
        # Check before spending a model call on an unreadable file.
        image = cv2.imread(str(image_path))
        if image is None:
            raise ValueError(f"Could not read image: {image_path}")

        detections: sv.Detections = model.predict(str(image_path), threshold=threshold)
        labels = _format_labels(
            detections.class_id, detections.confidence, prediction_classes
        )

        annotated = box_annotator.annotate(scene=image.copy(), detections=detections)
        annotated = label_annotator.annotate(
            scene=annotated, detections=detections, labels=labels
        )

        out_file = output_path / image_path.name
        # cv2.imwrite returns False instead of raising on failure
        # (e.g. an extension OpenCV cannot encode).
        if not cv2.imwrite(str(out_file), annotated):
            raise OSError(f"Failed to write annotated image: {out_file}")
        print(f"{image_path.name}: {len(detections)} detections -> {out_file}")
def main() -> None:
    """Parse command-line arguments and run inference on the given images.

    Exits through ``parser.error`` if the threshold lies outside [0, 1], or
    if any image argument is not an existing regular file.
    """
    parser = argparse.ArgumentParser(
        description="Run RF-DETR inference on aerial images"
    )
    parser.add_argument(
        "images",
        nargs="+",
        help="Input image path(s)",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=str(_DIR / "output" / "checkpoint_best_regular.pth"),
        help="Path to trained checkpoint",
    )
    parser.add_argument(
        "--model-size",
        type=str,
        default="medium",
        choices=list(MODEL_CLASSES),
        help="Model size, must match training (default: medium)",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=0.5,
        help="Confidence threshold (default: 0.5)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=str(_DIR / "inference_output"),
        help="Output directory for annotated images",
    )
    args = parser.parse_args()

    # Confidence scores are in [0, 1]. A threshold outside that range keeps
    # either every candidate or none, which is almost certainly a typo.
    # The chained comparison also rejects NaN.
    if not 0.0 <= args.threshold <= 1.0:
        parser.error(f"--threshold must be between 0 and 1, got {args.threshold}")

    image_paths = [Path(p) for p in args.images]
    # is_file() rather than exists(): a directory would pass exists() and
    # only fail later, when OpenCV tries to decode it as an image.
    missing = [p for p in image_paths if not p.is_file()]
    if missing:
        parser.error(f"Image(s) not found: {', '.join(str(p) for p in missing)}")

    run_inference(
        image_paths=image_paths,
        checkpoint=args.checkpoint,
        model_size=args.model_size,
        threshold=args.threshold,
        output_dir=args.output_dir,
    )
# Script entry point. main() returns None, so the process exits with status 0
# on success; argparse errors raise SystemExit(2) from inside main() as before.
if __name__ == "__main__":
    raise SystemExit(main())