# bambi-models / inference.py
# cpraschl's picture
# Upload inference.py with huggingface_hub
# 6f74262 verified
"""
Wildlife Detection with YOLOv26 — Inference Script
===================================================

Supports RGB and thermal drone imagery.

Usage:
    python inference.py --model rgb --source path/to/image.jpg
    python inference.py --model thermal_merged --source path/to/thermal/ --save
    python inference.py --model matched_rgb --source image.jpg --conf 0.3 --show

Available models:
    thermal_original — Baseline thermal model
    thermal_merged   — Refined thermal model (more training data)
    rgb              — Primary RGB model
    matched_rgb      — RGB model trained on matched RGB/thermal pairs
    matched_thermal  — Thermal model trained on matched RGB/thermal pairs
"""
import argparse
from pathlib import Path
from ultralytics import YOLO
# Registry of bundled checkpoints: short model name -> weights file path.
# Every checkpoint follows the same "<name>/weights/best.pt" layout, so the
# mapping is derived from the list of names instead of being spelled out.
MODELS = {
    model_name: f"{model_name}/weights/best.pt"
    for model_name in (
        "thermal_original",
        "thermal_merged",
        "rgb",
        "matched_rgb",
        "matched_thermal",
    )
}
def load_model(name: str) -> YOLO:
    """Load a YOLO model by registry name or by direct weights path.

    Args:
        name: Either a key of ``MODELS`` or any weights path/identifier that
            ``YOLO`` accepts. Names not found in ``MODELS`` are handed to
            ``YOLO`` unchanged.

    Returns:
        The ``YOLO`` instance built from the resolved weights path.
    """
    path = MODELS.get(name, name)
    if name in MODELS and not Path(path).exists():
        # Registry paths are relative, so on their own they only resolve when
        # the script is launched from the directory holding the weight
        # folders. If the CWD-relative path is missing, try the same path
        # next to this file; keep the original value when that is missing too
        # so YOLO reports the failure exactly as before.
        module_file = globals().get("__file__")  # absent in some embedded/exec contexts
        if module_file:
            bundled = Path(module_file).resolve().parent / path
            if bundled.exists():
                path = str(bundled)
    print(f"Loading model: {path}")
    return YOLO(path)
def run_inference(
    model_name: str = "rgb",
    source: str = "0",
    imgsz: int = 1024,
    conf: float = 0.25,
    iou: float = 0.45,
    show: bool = False,
    save: bool = False,
    save_txt: bool = False,
    project: str = "detections",
    name: str = "predict",
    device: str = "",
):
    """Run prediction with one model and print a per-image detection summary.

    Args:
        model_name: Registry name or weights path, resolved by ``load_model``.
        source: Image/video/folder path or webcam index, passed to ``predict``.
        imgsz: Inference image size.
        conf: Confidence threshold.
        iou: NMS IoU threshold.
        show: Display results.
        save: Save annotated images.
        save_txt: Save YOLO-format labels.
        project: Output project folder.
        name: Output run name.
        device: Device string such as ``"0"`` or ``"cpu"``; an empty string is
            forwarded as ``None`` so the library picks the device itself.

    Returns:
        The results object returned by ``model.predict``.
    """
    model = load_model(model_name)
    results = model.predict(
        source=source,
        imgsz=imgsz,
        conf=conf,
        iou=iou,
        show=show,
        save=save,
        save_txt=save_txt,
        project=project,
        name=name,
        device=device or None,
    )

    for index, result in enumerate(results, start=1):
        # ``result.boxes`` can be None when custom weights belong to a task
        # without box outputs; treat that as "no detections" instead of
        # crashing on len(None) after inference has already finished.
        boxes = result.boxes if result.boxes is not None else ()
        print(f"[Image {index}] {len(boxes)} detection(s)")
        for box in boxes:
            cls_id = int(box.cls.item())
            cls_name = result.names[cls_id]
            conf_val = box.conf.item()
            # Round coordinates to one decimal purely for readable output.
            xyxy = [round(v, 1) for v in box.xyxy[0].tolist()]
            print(f" {cls_name:15s} conf={conf_val:.2f} bbox={xyxy}")
    return results
def compare_modalities(
    rgb_source: str,
    thermal_source: str,
    conf: float = 0.25,
    imgsz: int = 1024,
):
    """Compare RGB vs thermal detection counts on co-registered image pairs.

    Runs the ``matched_rgb`` model on ``rgb_source`` and the
    ``matched_thermal`` model on ``thermal_source``, then prints the number of
    detections for each pair. Results are paired purely by position, so both
    sources must yield their images in the same order.
    Useful for the matched dataset experiments.

    Args:
        rgb_source: Source passed to the RGB model's ``predict``.
        thermal_source: Source passed to the thermal model's ``predict``.
        conf: Confidence threshold used for both models.
        imgsz: Inference image size used for both models.

    Returns:
        A ``(rgb_results, thermal_results)`` tuple with the complete,
        untruncated outputs of both ``predict`` calls.
    """
    rgb_model = load_model("matched_rgb")
    thermal_model = load_model("matched_thermal")
    rgb_results = rgb_model.predict(rgb_source, imgsz=imgsz, conf=conf, verbose=False)
    thermal_results = thermal_model.predict(thermal_source, imgsz=imgsz, conf=conf, verbose=False)

    rgb_count, thermal_count = len(rgb_results), len(thermal_results)
    if rgb_count != thermal_count:
        # zip() below stops at the shorter sequence; say so explicitly rather
        # than silently dropping the unpaired images from the report.
        print(
            f"Warning: {rgb_count} RGB result(s) vs {thermal_count} thermal result(s); "
            f"only the first {min(rgb_count, thermal_count)} pair(s) are compared."
        )

    for i, (r_rgb, r_thm) in enumerate(zip(rgb_results, thermal_results)):
        print(f"\n--- Pair {i+1} ---")
        print(f" RGB detections: {len(r_rgb.boxes)}")
        print(f" Thermal detections: {len(r_thm.boxes)}")
    return rgb_results, thermal_results
def main(argv=None):
    """Parse command-line arguments and run inference or a modality comparison.

    With ``--compare RGB_SOURCE THERMAL_SOURCE`` the matched RGB/thermal
    models are compared via ``compare_modalities`` and nothing else runs.
    Otherwise a single model is run through ``run_inference``.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read the process command line, which is the behavior of
            the plain ``main()`` call used by the script entry point.
    """
    parser = argparse.ArgumentParser(description="Wildlife YOLOv26 Inference")
    parser.add_argument(
        "--model",
        default="rgb",
        choices=list(MODELS.keys()) + ["custom"],
        help="Model to use. Pass a file path with --model custom --weights <path>.",
    )
    parser.add_argument("--weights", default=None, help="Direct path to .pt weights file.")
    parser.add_argument("--source", default="0", help="Image/video/folder path or webcam index.")
    parser.add_argument("--imgsz", type=int, default=1024, help="Inference image size.")
    parser.add_argument("--conf", type=float, default=0.25, help="Confidence threshold.")
    parser.add_argument("--iou", type=float, default=0.45, help="NMS IoU threshold.")
    parser.add_argument("--show", action="store_true", help="Display results.")
    parser.add_argument("--save", action="store_true", help="Save annotated images.")
    parser.add_argument("--save-txt", action="store_true", help="Save YOLO-format labels.")
    parser.add_argument("--project", default="detections", help="Output project folder.")
    parser.add_argument("--name", default="predict", help="Output run name.")
    parser.add_argument("--device", default="", help="CUDA device, e.g. '0' or 'cpu'.")
    parser.add_argument(
        "--compare",
        nargs=2,
        metavar=("RGB_SOURCE", "THERMAL_SOURCE"),
        help="Compare RGB and thermal models on co-registered pairs.",
    )
    args = parser.parse_args(argv)

    if args.compare:
        compare_modalities(
            rgb_source=args.compare[0],
            thermal_source=args.compare[1],
            conf=args.conf,
            imgsz=args.imgsz,
        )
        return

    if args.model == "custom":
        if not args.weights:
            # Previously the literal string "custom" was handed to YOLO as a
            # weights path, which failed far from the real cause.
            parser.error("--model custom requires --weights <path>")
        model_name = args.weights
    else:
        if args.weights:
            # --weights only takes effect together with --model custom; make
            # the otherwise silent no-op visible to the user.
            print(
                f"Warning: --weights is ignored unless --model custom is given "
                f"(using model '{args.model}')."
            )
        model_name = args.model

    run_inference(
        model_name=model_name,
        source=args.source,
        imgsz=args.imgsz,
        conf=args.conf,
        iou=args.iou,
        show=args.show,
        save=args.save,
        save_txt=args.save_txt,
        project=args.project,
        name=args.name,
        device=args.device,
    )
# Script entry point: run the CLI only when this file is executed directly,
# so importing the module (e.g. to reuse run_inference) has no side effects.
if __name__ == "__main__":
    main()