"""
Wildlife Detection with YOLOv26 — Inference Script
===================================================
Supports RGB and thermal drone imagery.
Usage:
python inference.py --model rgb --source path/to/image.jpg
python inference.py --model thermal_merged --source path/to/thermal/ --save
python inference.py --model matched_rgb --source image.jpg --conf 0.3 --show
Available models:
thermal_original — Baseline thermal model
thermal_merged — Refined thermal model (more training data)
rgb — Primary RGB model
matched_rgb — RGB model trained on matched RGB/thermal pairs
matched_thermal — Thermal model trained on matched RGB/thermal pairs
"""
import argparse
from pathlib import Path
from ultralytics import YOLO
# Registry of the bundled checkpoints: short model name -> path of its
# ``best.pt`` weights file. The paths are relative (presumably to the working
# directory the script is launched from -- confirm against the repo layout).
MODELS = dict(
    thermal_original="thermal_original/weights/best.pt",
    thermal_merged="thermal_merged/weights/best.pt",
    rgb="rgb/weights/best.pt",
    matched_rgb="matched_rgb/weights/best.pt",
    matched_thermal="matched_thermal/weights/best.pt",
)
def load_model(name: str) -> YOLO:
    """Build a YOLO model from a registered model name or a weights path.

    Args:
        name: Either a key of ``MODELS`` or any other string. A registered key
            is translated to its weights path; anything else is handed to
            ``YOLO`` unchanged.

    Returns:
        The ``YOLO`` object constructed from the resolved path.
    """
    # Unregistered names fall through untouched so callers can pass a path.
    weights = MODELS[name] if name in MODELS else name
    print(f"Loading model: {weights}")
    return YOLO(weights)
def run_inference(
    model_name: str = "rgb",
    source: str = "0",
    imgsz: int = 1024,
    conf: float = 0.25,
    iou: float = 0.45,
    show: bool = False,
    save: bool = False,
    save_txt: bool = False,
    project: str = "detections",
    name: str = "predict",
    device: str = "",
):
    """Predict on ``source`` with the chosen model and print every detection.

    Args:
        model_name: Key of ``MODELS`` or a weights path (resolved by
            ``load_model``).
        source: Input forwarded to ``model.predict`` as ``source``.
        imgsz: Inference image size forwarded to ``model.predict``.
        conf: Confidence threshold forwarded to ``model.predict``.
        iou: IoU threshold forwarded to ``model.predict``.
        show: Forwarded to ``model.predict`` as ``show``.
        save: Forwarded to ``model.predict`` as ``save``.
        save_txt: Forwarded to ``model.predict`` as ``save_txt``.
        project: Forwarded to ``model.predict`` as ``project``.
        name: Forwarded to ``model.predict`` as ``name``.
        device: Device string; an empty string is passed on as ``None``.

    Returns:
        Whatever ``model.predict`` returned, after the summary was printed.
    """
    detector = load_model(model_name)

    predict_options = {
        "source": source,
        "imgsz": imgsz,
        "conf": conf,
        "iou": iou,
        "show": show,
        "save": save,
        "save_txt": save_txt,
        "project": project,
        "name": name,
        # An empty device string is replaced by None before the call.
        "device": device or None,
    }
    results = detector.predict(**predict_options)

    for image_no, result in enumerate(results, start=1):
        boxes = result.boxes
        print(f"[Image {image_no}] {len(boxes)} detection(s)")
        for detection in boxes:
            label = result.names[int(detection.cls.item())]
            score = detection.conf.item()
            # Corner coordinates of the first row, rounded to one decimal.
            corners = [round(coord, 1) for coord in detection.xyxy[0].tolist()]
            print(f" {label:15s} conf={score:.2f} bbox={corners}")

    return results
def compare_modalities(
    rgb_source: str,
    thermal_source: str,
    conf: float = 0.25,
    imgsz: int = 1024,
):
    """Compare RGB vs thermal detection counts on co-registered image pairs.

    Loads the ``matched_rgb`` and ``matched_thermal`` models, runs each on its
    own source and prints the number of detections for every pair. Results
    are paired purely by position in the two result sequences, so both
    sources are assumed to yield their images in matching order (TODO
    confirm the ordering guarantee for directory sources).

    Args:
        rgb_source: Source passed to the ``matched_rgb`` model.
        thermal_source: Source passed to the ``matched_thermal`` model.
        conf: Confidence threshold used for both models.
        imgsz: Inference image size used for both models.

    Returns:
        A ``(rgb_results, thermal_results)`` tuple holding the full,
        untruncated output of each ``predict`` call.
    """
    rgb_model = load_model("matched_rgb")
    thermal_model = load_model("matched_thermal")
    rgb_results = rgb_model.predict(rgb_source, imgsz=imgsz, conf=conf, verbose=False)
    thermal_results = thermal_model.predict(
        thermal_source, imgsz=imgsz, conf=conf, verbose=False
    )

    # zip() stops at the shorter sequence, so a count mismatch would silently
    # drop images from the comparison; make that visible to the user.
    n_rgb, n_thermal = len(rgb_results), len(thermal_results)
    if n_rgb != n_thermal:
        print(
            f"Warning: {n_rgb} RGB result(s) vs {n_thermal} thermal result(s); "
            f"only the first {min(n_rgb, n_thermal)} pair(s) are compared."
        )

    for i, (r_rgb, r_thm) in enumerate(zip(rgb_results, thermal_results)):
        print(f"\n--- Pair {i+1} ---")
        print(f" RGB detections: {len(r_rgb.boxes)}")
        print(f" Thermal detections: {len(r_thm.boxes)}")
    return rgb_results, thermal_results
def main(argv=None):
    """Parse command-line arguments and run inference or a modality comparison.

    With ``--compare RGB_SOURCE THERMAL_SOURCE`` the matched RGB/thermal
    models are compared (only ``--conf`` and ``--imgsz`` are forwarded) and
    the function returns. Otherwise ``run_inference`` is called with the
    selected model and the remaining options.

    Model selection:
        * ``--model <name>`` picks one of the registered ``MODELS`` keys.
        * ``--model custom --weights <path>`` uses the given weights file.
        * ``--model custom`` without ``--weights`` is rejected with a usage
          error instead of trying to load a model literally named "custom".
        * ``--weights`` without ``--model custom`` is ignored, as before, but
          a warning is now printed so the omission is not silent.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description="Wildlife YOLOv26 Inference")
    parser.add_argument(
        "--model",
        default="rgb",
        choices=list(MODELS.keys()) + ["custom"],
        help="Model to use. Pass a file path with --model custom --weights <path>.",
    )
    parser.add_argument("--weights", default=None, help="Direct path to .pt weights file.")
    parser.add_argument("--source", default="0", help="Image/video/folder path or webcam index.")
    parser.add_argument("--imgsz", type=int, default=1024, help="Inference image size.")
    parser.add_argument("--conf", type=float, default=0.25, help="Confidence threshold.")
    parser.add_argument("--iou", type=float, default=0.45, help="NMS IoU threshold.")
    parser.add_argument("--show", action="store_true", help="Display results.")
    parser.add_argument("--save", action="store_true", help="Save annotated images.")
    parser.add_argument("--save-txt", action="store_true", help="Save YOLO-format labels.")
    parser.add_argument("--project", default="detections", help="Output project folder.")
    parser.add_argument("--name", default="predict", help="Output run name.")
    parser.add_argument("--device", default="", help="CUDA device, e.g. '0' or 'cpu'.")
    parser.add_argument(
        "--compare",
        nargs=2,
        metavar=("RGB_SOURCE", "THERMAL_SOURCE"),
        help="Compare RGB and thermal models on co-registered pairs.",
    )
    args = parser.parse_args(argv)

    if args.compare:
        compare_modalities(
            rgb_source=args.compare[0],
            thermal_source=args.compare[1],
            conf=args.conf,
            imgsz=args.imgsz,
        )
        return

    if args.model == "custom":
        if not args.weights:
            # "custom" is only a placeholder choice, never a loadable model.
            parser.error("--model custom requires --weights <path>")
        model_name = args.weights
    else:
        if args.weights:
            print(
                f"Warning: --weights is ignored unless --model custom is given "
                f"(using --model {args.model})."
            )
        model_name = args.model

    run_inference(
        model_name=model_name,
        source=args.source,
        imgsz=args.imgsz,
        conf=args.conf,
        iou=args.iou,
        show=args.show,
        save=args.save,
        save_txt=args.save_txt,
        project=args.project,
        name=args.name,
        device=args.device,
    )
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|