"""
Wildlife Detection with YOLOv26 — Inference Script
===================================================

Supports RGB and thermal drone imagery.

Usage:
    python inference.py --model rgb --source path/to/image.jpg
    python inference.py --model thermal_merged --source path/to/thermal/ --save
    python inference.py --model matched_rgb --source image.jpg --conf 0.3 --show
    python inference.py --model custom --weights path/to/best.pt --source image.jpg
    python inference.py --compare path/to/rgb/ path/to/thermal/

Available models:
    thermal_original  — Baseline thermal model
    thermal_merged    — Refined thermal model (more training data)
    rgb               — Primary RGB model
    matched_rgb       — RGB model trained on matched RGB/thermal pairs
    matched_thermal   — Thermal model trained on matched RGB/thermal pairs

The weights paths in ``MODELS`` are relative, so they are looked up relative
to the current working directory; run the script from the folder that
contains the model directories (or pass ``--model custom --weights PATH``).
"""

import argparse
from pathlib import Path
from typing import Optional

from ultralytics import YOLO

# Registry of named models -> relative weights paths (see module docstring).
MODELS = {
    "thermal_original": "thermal_original/weights/best.pt",
    "thermal_merged": "thermal_merged/weights/best.pt",
    "rgb": "rgb/weights/best.pt",
    "matched_rgb": "matched_rgb/weights/best.pt",
    "matched_thermal": "matched_thermal/weights/best.pt",
}


def load_model(name: str) -> YOLO:
    """Load a model by registry name or direct weights path.

    Args:
        name: A key of ``MODELS``, or any other string, which is passed to
            ``YOLO`` unchanged as a weights path.

    Returns:
        The loaded ultralytics ``YOLO`` model.
    """
    path = MODELS.get(name, name)
    print(f"Loading model: {path}")
    return YOLO(path)


def _boxes_of(result):
    """Return ``result.boxes``, or an empty list when the result has no boxes.

    ``Results.boxes`` is ``None`` for models that do not produce bounding
    boxes (e.g. classification weights passed via ``--model custom``), and
    ``len(None)`` would raise ``TypeError``.
    """
    boxes = result.boxes
    return [] if boxes is None else boxes


def run_inference(
    model_name: str = "rgb",
    source: str = "0",
    imgsz: int = 1024,
    conf: float = 0.25,
    iou: float = 0.45,
    show: bool = False,
    save: bool = False,
    save_txt: bool = False,
    project: str = "detections",
    name: str = "predict",
    device: str = "",
):
    """Run a single model on *source*, print every detection, return results.

    Args:
        model_name: Key of ``MODELS`` or a direct path to weights.
        source: Image, video or folder path, or a webcam index as a string.
        imgsz: Inference image size.
        conf: Confidence threshold; lower-scoring detections are dropped.
        iou: NMS IoU threshold.
        show: Display annotated results.
        save: Save annotated images under ``project/name``.
        save_txt: Save YOLO-format label files.
        project: Output project folder.
        name: Output run name inside ``project``.
        device: Device string such as ``"0"`` or ``"cpu"``; an empty string
            is passed as ``None`` so ultralytics picks the device.

    Returns:
        The list of ultralytics ``Results`` objects returned by ``predict``.

    Note:
        ``predict`` is called without ``stream=True``, so for long videos or
        a webcam source ultralytics keeps every frame's result in memory.
    """
    model = load_model(model_name)
    results = model.predict(
        source=source,
        imgsz=imgsz,
        conf=conf,
        iou=iou,
        show=show,
        save=save,
        save_txt=save_txt,
        project=project,
        name=name,
        device=device or None,
    )

    for idx, result in enumerate(results, start=1):
        boxes = _boxes_of(result)
        print(f"[Image {idx}] {len(boxes)} detection(s)")
        for box in boxes:
            cls_id = int(box.cls.item())
            cls_name = result.names[cls_id]
            conf_val = box.conf.item()
            xyxy = [round(v, 1) for v in box.xyxy[0].tolist()]
            print(f"  {cls_name:15s} conf={conf_val:.2f} bbox={xyxy}")
    return results


def compare_modalities(
    rgb_source: str,
    thermal_source: str,
    conf: float = 0.25,
    imgsz: int = 1024,
    iou: Optional[float] = None,
    device: str = "",
):
    """Compare RGB vs thermal detection counts on co-registered image pairs.

    Useful for the matched dataset experiments. Runs ``matched_rgb`` on
    *rgb_source* and ``matched_thermal`` on *thermal_source*, then pairs the
    results by position.

    Args:
        rgb_source: RGB image or folder.
        thermal_source: Thermal image or folder, co-registered with
            *rgb_source*.
        conf: Confidence threshold for both models.
        imgsz: Inference image size for both models.
        iou: NMS IoU threshold; ``None`` keeps the ultralytics default.
        device: Device string such as ``"0"`` or ``"cpu"``; empty lets
            ultralytics choose.

    Returns:
        ``(rgb_results, thermal_results)`` as returned by ``predict``.

    NOTE(review): pairing is purely positional, so folder sources must list
    matching files in the same order; the printed file names make
    mismatches visible.
    """
    rgb_model = load_model("matched_rgb")
    thermal_model = load_model("matched_thermal")

    # Only forward optional settings when given, so omitting them keeps the
    # previous behaviour of using the ultralytics defaults.
    predict_kwargs = {"imgsz": imgsz, "conf": conf, "verbose": False}
    if iou is not None:
        predict_kwargs["iou"] = iou
    if device:
        predict_kwargs["device"] = device

    rgb_results = rgb_model.predict(rgb_source, **predict_kwargs)
    thermal_results = thermal_model.predict(thermal_source, **predict_kwargs)

    n_rgb, n_thm = len(rgb_results), len(thermal_results)
    if n_rgb != n_thm:
        # zip() below stops at the shorter sequence; say so instead of
        # silently dropping the unmatched images.
        print(
            f"Warning: {n_rgb} RGB image(s) vs {n_thm} thermal image(s); "
            f"only the first {min(n_rgb, n_thm)} pair(s) are compared."
        )

    for idx, (r_rgb, r_thm) in enumerate(zip(rgb_results, thermal_results), start=1):
        print(f"\n--- Pair {idx} ---")
        print(f"  Files:              {Path(r_rgb.path).name} | {Path(r_thm.path).name}")
        print(f"  RGB detections:     {len(_boxes_of(r_rgb))}")
        print(f"  Thermal detections: {len(_boxes_of(r_thm))}")
    return rgb_results, thermal_results


def main():
    """Parse command-line arguments and run inference or an RGB/thermal comparison."""
    parser = argparse.ArgumentParser(description="Wildlife YOLOv26 Inference")
    parser.add_argument(
        "--model",
        default="rgb",
        choices=list(MODELS.keys()) + ["custom"],
        help="Model to use. Use --model custom --weights PATH to load a weights file directly.",
    )
    parser.add_argument("--weights", default=None, help="Direct path to .pt weights file.")
    parser.add_argument("--source", default="0", help="Image/video/folder path or webcam index.")
    parser.add_argument("--imgsz", type=int, default=1024, help="Inference image size.")
    parser.add_argument("--conf", type=float, default=0.25, help="Confidence threshold.")
    parser.add_argument("--iou", type=float, default=0.45, help="NMS IoU threshold.")
    parser.add_argument("--show", action="store_true", help="Display results.")
    parser.add_argument("--save", action="store_true", help="Save annotated images.")
    parser.add_argument("--save-txt", action="store_true", help="Save YOLO-format labels.")
    parser.add_argument("--project", default="detections", help="Output project folder.")
    parser.add_argument("--name", default="predict", help="Output run name.")
    parser.add_argument("--device", default="", help="CUDA device, e.g. '0' or 'cpu'.")
    parser.add_argument(
        "--compare",
        nargs=2,
        metavar=("RGB_SOURCE", "THERMAL_SOURCE"),
        help="Compare RGB and thermal models on co-registered pairs.",
    )
    args = parser.parse_args()

    if args.compare:
        # --conf, --imgsz, --iou and --device all apply to the comparison run.
        compare_modalities(
            rgb_source=args.compare[0],
            thermal_source=args.compare[1],
            conf=args.conf,
            imgsz=args.imgsz,
            iou=args.iou,
            device=args.device,
        )
        return

    if args.model == "custom":
        # Without this check "custom" itself would be handed to YOLO as a path.
        if not args.weights:
            parser.error("--model custom requires --weights PATH")
        model_name = args.weights
    else:
        if args.weights:
            print("Warning: --weights is ignored unless --model custom is used.")
        model_name = args.model

    run_inference(
        model_name=model_name,
        source=args.source,
        imgsz=args.imgsz,
        conf=args.conf,
        iou=args.iou,
        show=args.show,
        save=args.save,
        save_txt=args.save_txt,
        project=args.project,
        name=args.name,
        device=args.device,
    )


if __name__ == "__main__":
    main()