"""
Wildlife Detection with YOLOv26 — Inference Script
===================================================
Supports RGB and thermal drone imagery.

Usage:
    python inference.py --model rgb --source path/to/image.jpg
    python inference.py --model thermal_merged --source path/to/thermal/ --save
    python inference.py --model matched_rgb --source image.jpg --conf 0.3 --show
    python inference.py --model custom --weights path/to/best.pt --source image.jpg
    python inference.py --compare path/to/rgb/ path/to/thermal/

Available models:
    thermal_original — Baseline thermal model
    thermal_merged   — Refined thermal model (more training data)
    rgb              — Primary RGB model
    matched_rgb      — RGB model trained on matched RGB/thermal pairs
    matched_thermal  — Thermal model trained on matched RGB/thermal pairs
"""
import argparse
from pathlib import Path

from ultralytics import YOLO
# Registry of bundled models. Each name maps to the relative weights path
# "<name>/weights/best.pt". Key order also fixes the order of the CLI
# --model choices (built from MODELS.keys()).
MODELS = {
    run_dir: f"{run_dir}/weights/best.pt"
    for run_dir in (
        "thermal_original",
        "thermal_merged",
        "rgb",
        "matched_rgb",
        "matched_thermal",
    )
}
def load_model(name: str) -> YOLO:
    """Load a YOLO model by registry name or direct weights path.

    ``name`` is looked up in ``MODELS``. Names not in the registry are passed
    through unchanged (e.g. a path to a custom ``.pt`` file).

    Relative paths are tried against the current working directory first, as
    before. If nothing exists there, the same relative path is tried next to
    this script, so the bundled models load even when the script is launched
    from another directory. If neither location exists, the original string is
    handed to ``YOLO`` untouched, which then deals with it.

    Args:
        name: A key of ``MODELS`` or any path/identifier accepted by ``YOLO``.

    Returns:
        The loaded ``YOLO`` model.
    """
    path = MODELS.get(name, name)
    weights = Path(path)
    if not weights.is_absolute() and not weights.exists():
        # Registry paths are relative; fall back to the script's own directory.
        beside_script = Path(__file__).resolve().parent / weights
        if beside_script.exists():
            path = str(beside_script)
    print(f"Loading model: {path}")
    return YOLO(path)
def run_inference(
    model_name: str = "rgb",
    source: str = "0",
    imgsz: int = 1024,
    conf: float = 0.25,
    iou: float = 0.45,
    show: bool = False,
    save: bool = False,
    save_txt: bool = False,
    project: str = "detections",
    name: str = "predict",
    device: str = "",
):
    """Run one model on ``source``, print every detection, and return the results.

    ``model_name`` is resolved through ``load_model``. The other arguments are
    forwarded to ``predict``. An empty ``device`` is sent as ``None``.

    For each result, a summary line ``[Image k] N detection(s)`` is printed.
    It is followed by one line per box with the class name, confidence and
    ``xyxy`` coordinates rounded to one decimal.

    Returns:
        The object returned by ``predict``, unchanged.
    """
    detector = load_model(model_name)
    predict_options = {
        "source": source,
        "imgsz": imgsz,
        "conf": conf,
        "iou": iou,
        "show": show,
        "save": save,
        "save_txt": save_txt,
        "project": project,
        "name": name,
        "device": device or None,  # "" -> None
    }
    results = detector.predict(**predict_options)

    for image_no, result in enumerate(results, start=1):
        boxes = result.boxes
        print(f"[Image {image_no}] {len(boxes)} detection(s)")
        class_names = result.names
        for box in boxes:
            label = class_names[int(box.cls.item())]
            score = box.conf.item()
            corners = [round(coord, 1) for coord in box.xyxy[0].tolist()]
            print(f" {label:15s} conf={score:.2f} bbox={corners}")
    return results
def compare_modalities(
    rgb_source: str,
    thermal_source: str,
    conf: float = 0.25,
    imgsz: int = 1024,
    device: str = "",
):
    """
    Compare RGB vs thermal detections on co-registered image pairs.

    Useful for the matched dataset experiments. Runs the ``matched_rgb`` model
    on ``rgb_source`` and the ``matched_thermal`` model on ``thermal_source``,
    then prints the detection counts for each pair.

    Pairing is positional: the i-th RGB result is compared with the i-th
    thermal result, so both sources must yield their images in the same order.
    To make misalignment visible:
    - the file names of each pair are printed;
    - a warning is printed when the two sources yield a different number of
      results, in which case only the common prefix is compared.

    Args:
        rgb_source: RGB image path, folder, or other source accepted by
            ``predict``.
        thermal_source: Thermal counterpart of ``rgb_source``.
        conf: Confidence threshold applied to both models.
        imgsz: Inference image size for both models.
        device: Device string such as ``"0"`` or ``"cpu"``. An empty string is
            sent as ``None``, the same convention as ``run_inference``.

    Returns:
        ``(rgb_results, thermal_results)`` exactly as returned by ``predict``.
    """
    rgb_model = load_model("matched_rgb")
    thermal_model = load_model("matched_thermal")
    predict_kwargs = {
        "imgsz": imgsz,
        "conf": conf,
        "verbose": False,
        "device": device if device else None,
    }
    rgb_results = rgb_model.predict(rgb_source, **predict_kwargs)
    thermal_results = thermal_model.predict(thermal_source, **predict_kwargs)

    n_rgb, n_thm = len(rgb_results), len(thermal_results)
    if n_rgb != n_thm:
        # zip() below silently drops the unmatched tail; make that explicit.
        print(
            f"Warning: {n_rgb} RGB result(s) vs {n_thm} thermal result(s); "
            f"only the first {min(n_rgb, n_thm)} pair(s) are compared."
        )

    for i, (r_rgb, r_thm) in enumerate(zip(rgb_results, thermal_results)):
        print(f"\n--- Pair {i+1} ---")
        # Show which files were paired so ordering mistakes are easy to spot.
        print(f" Files: {Path(r_rgb.path).name} <-> {Path(r_thm.path).name}")
        print(f" RGB detections: {len(r_rgb.boxes)}")
        print(f" Thermal detections: {len(r_thm.boxes)}")
    return rgb_results, thermal_results
def main():
    """Parse command-line arguments, then run inference or an RGB/thermal comparison.

    With ``--compare RGB_SOURCE THERMAL_SOURCE``, the matched RGB and thermal
    models are compared, and only ``--conf`` and ``--imgsz`` are forwarded.
    Otherwise a single model is run on ``--source``. That model is either the
    registry model named by ``--model``, or the weights file given by
    ``--weights`` together with ``--model custom``.
    """
    parser = argparse.ArgumentParser(description="Wildlife YOLOv26 Inference")
    parser.add_argument(
        "--model",
        default="rgb",
        choices=list(MODELS.keys()) + ["custom"],
        help="Model to use. Pass a file path with --model custom --weights <path>.",
    )
    parser.add_argument("--weights", default=None, help="Direct path to .pt weights file.")
    parser.add_argument("--source", default="0", help="Image/video/folder path or webcam index.")
    parser.add_argument("--imgsz", type=int, default=1024, help="Inference image size.")
    parser.add_argument("--conf", type=float, default=0.25, help="Confidence threshold.")
    parser.add_argument("--iou", type=float, default=0.45, help="NMS IoU threshold.")
    parser.add_argument("--show", action="store_true", help="Display results.")
    parser.add_argument("--save", action="store_true", help="Save annotated images.")
    parser.add_argument("--save-txt", action="store_true", help="Save YOLO-format labels.")
    parser.add_argument("--project", default="detections", help="Output project folder.")
    parser.add_argument("--name", default="predict", help="Output run name.")
    parser.add_argument("--device", default="", help="CUDA device, e.g. '0' or 'cpu'.")
    parser.add_argument(
        "--compare",
        nargs=2,
        metavar=("RGB_SOURCE", "THERMAL_SOURCE"),
        help="Compare RGB and thermal models on co-registered pairs.",
    )
    args = parser.parse_args()

    if args.compare:
        compare_modalities(
            rgb_source=args.compare[0],
            thermal_source=args.compare[1],
            conf=args.conf,
            imgsz=args.imgsz,
        )
        return

    if args.model == "custom":
        if not args.weights:
            # Without this check the literal string "custom" would be passed
            # to YOLO as a weights path and fail with an obscure error.
            parser.error("--model custom requires --weights <path>.")
        model_name = args.weights
    else:
        if args.weights:
            # --weights only applies together with --model custom; say so
            # instead of silently ignoring it.
            print(
                f"Warning: --weights is ignored unless --model custom is given; "
                f"using model '{args.model}'."
            )
        model_name = args.model

    run_inference(
        model_name=model_name,
        source=args.source,
        imgsz=args.imgsz,
        conf=args.conf,
        iou=args.iou,
        show=args.show,
        save=args.save,
        save_txt=args.save_txt,
        project=args.project,
        name=args.name,
        device=args.device,
    )


if __name__ == "__main__":
    main()