# car-detection / training / inference.py
# Last commit: a2e9c4d by socks22 ("fix detect")
"""Run inference with a trained RF-DETR checkpoint on aerial images."""
import argparse
from collections.abc import Iterable, Mapping
from pathlib import Path

import cv2
import rfdetr
import supervision as sv
# Directory containing this script; the CLI's default checkpoint and output
# locations are built relative to it rather than to the working directory.
_DIR = Path(__file__).resolve().parent

# Model-size name -> RF-DETR model class. Insertion order matters: it is the
# order listed by --model-size choices and by the "Unknown model_size" error.
MODEL_CLASSES: dict[str, type] = dict(
    nano=rfdetr.RFDETRNano,
    small=rfdetr.RFDETRSmall,
    base=rfdetr.RFDETRBase,
    medium=rfdetr.RFDETRMedium,
    large=rfdetr.RFDETRLarge,
)

# Predicted class id -> display name used in the annotation labels
# (0 -> "empty_spot", 1 -> "parked_car").
# NOTE(review): must match the category order used at training time — confirm.
prediction_classes = dict(enumerate(("empty_spot", "parked_car")))
def run_inference(
image_paths: list[Path],
checkpoint: str | Path,
model_size: str = "medium",
threshold: float = 0.5,
output_dir: str | Path = "./inference_output2",
) -> None:
"""Load an RF-DETR checkpoint and run detection on input images.
Args:
image_paths: Paths to input images.
checkpoint: Path to trained checkpoint file.
model_size: Model variant (must match training size).
threshold: Confidence threshold for detections.
output_dir: Directory to save annotated images.
"""
model_cls = MODEL_CLASSES.get(model_size)
if model_cls is None:
raise ValueError(
f"Unknown model_size {model_size!r}, "
f"choose from: {', '.join(MODEL_CLASSES)}"
)
model = model_cls(pretrain_weights=str(checkpoint))
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
box_annotator = sv.BoxAnnotator()
label_annotator = sv.LabelAnnotator()
for image_path in image_paths:
detections: sv.Detections = model.predict(str(image_path), threshold=threshold)
image = cv2.imread(str(image_path))
labels = [
f"{prediction_classes[detections.class_id[i]]} {conf:.2f}"
for i, conf in enumerate(detections.confidence)
]
annotated = box_annotator.annotate(scene=image.copy(), detections=detections)
annotated = label_annotator.annotate(
scene=annotated, detections=detections, labels=labels
)
out_file = output_path / image_path.name
cv2.imwrite(str(out_file), annotated)
print(f"{image_path.name}: {len(detections)} detections -> {out_file}")
def main() -> None:
    """Command-line entry point: parse arguments and run inference.

    Default checkpoint and output locations are relative to this script's
    directory (``_DIR``). Invalid input paths are reported through
    ``parser.error`` before any model is loaded.
    """
    parser = argparse.ArgumentParser(
        description="Run RF-DETR inference on aerial images"
    )
    parser.add_argument(
        "images",
        nargs="+",
        help="Input image path(s)",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=str(_DIR / "output" / "checkpoint_best_regular.pth"),
        help="Path to trained checkpoint",
    )
    parser.add_argument(
        "--model-size",
        type=str,
        default="medium",
        choices=list(MODEL_CLASSES),
        help="Model size, must match training (default: medium)",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=0.5,
        help="Confidence threshold (default: 0.5)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=str(_DIR / "inference_output"),
        help="Output directory for annotated images",
    )
    args = parser.parse_args()

    image_paths = [Path(p) for p in args.images]
    # is_file() rather than exists(): a directory "exists" but cannot be
    # decoded as an image, so reject it here with a clear usage error.
    missing = [p for p in image_paths if not p.is_file()]
    if missing:
        parser.error(f"Image(s) not found: {', '.join(str(p) for p in missing)}")

    run_inference(
        image_paths=image_paths,
        checkpoint=args.checkpoint,
        model_size=args.model_size,
        threshold=args.threshold,
        output_dir=args.output_dir,
    )
# Script entry point. main() returns None, so SystemExit(None) exits with
# status 0, exactly as falling off the end of the module would.
if __name__ == "__main__":
    raise SystemExit(main())