File size: 5,480 Bytes
6f74262
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
"""

Wildlife Detection with YOLOv26 — Inference Script

===================================================

Supports RGB and thermal drone imagery.



Usage:

    python inference.py --model rgb --source path/to/image.jpg

    python inference.py --model thermal_merged --source path/to/thermal/ --save

    python inference.py --model matched_rgb --source image.jpg --conf 0.3 --show



Available models:

    thermal_original  — Baseline thermal model

    thermal_merged    — Refined thermal model (more training data)

    rgb               — Primary RGB model

    matched_rgb       — RGB model trained on matched RGB/thermal pairs

    matched_thermal   — Thermal model trained on matched RGB/thermal pairs

"""

import argparse
from pathlib import Path
from ultralytics import YOLO


# Registry of bundled checkpoints: short CLI name -> weights file.
# Every training run stores its best checkpoint at "<run>/weights/best.pt",
# so the mapping is generated from the run names. Insertion order is kept and
# becomes the order of the --model choices in the CLI.
# NOTE(review): these are relative paths, so they presumably resolve against
# the current working directory rather than this file's directory — confirm.
MODELS = {
    run: f"{run}/weights/best.pt"
    for run in (
        "thermal_original",
        "thermal_merged",
        "rgb",
        "matched_rgb",
        "matched_thermal",
    )
}


def load_model(name: str) -> YOLO:
    """
    Build a YOLO model from a registered name or a direct weights path.

    If ``name`` is a key of ``MODELS`` its registered weights path is used;
    otherwise ``name`` itself is handed to ``YOLO`` unchanged. The resolved
    path is printed before loading.
    """
    weights = MODELS[name] if name in MODELS else name
    print(f"Loading model: {weights}")
    return YOLO(weights)


def run_inference(
    model_name: str = "rgb",
    source: str = "0",
    imgsz: int = 1024,
    conf: float = 0.25,
    iou: float = 0.45,
    show: bool = False,
    save: bool = False,
    save_txt: bool = False,
    project: str = "detections",
    name: str = "predict",
    device: str = "",
):
    """
    Run one model over ``source``, print a per-image detection summary, and
    return the results.

    Args:
        model_name: Key of ``MODELS`` or a direct path to a weights file.
        source: Anything ``YOLO.predict`` accepts (image, folder, video, or a
            webcam index given as a string such as ``"0"``).
        imgsz: Inference image size.
        conf: Confidence threshold below which detections are dropped.
        iou: NMS IoU threshold.
        show: Display annotated results.
        save: Save annotated outputs under ``project/name``.
        save_txt: Also save YOLO-format label files.
        project: Output project folder.
        name: Output run name inside ``project``.
        device: Device string such as ``"0"`` or ``"cpu"``; an empty string
            is passed as ``None`` so Ultralytics picks the device itself.

    Returns:
        The results object returned by ``YOLO.predict`` (iterated here as one
        entry per image/frame).
    """
    model = load_model(model_name)

    results = model.predict(
        source=source,
        imgsz=imgsz,
        conf=conf,
        iou=iou,
        show=show,
        save=save,
        save_txt=save_txt,
        project=project,
        name=name,
        device=device or None,
    )

    for i, result in enumerate(results, start=1):
        boxes = result.boxes
        if boxes is None:
            # Non-detection outputs (e.g. classification or OBB weights given
            # through --model custom) leave ``boxes`` unset; len(None) would
            # raise TypeError, so report and move on instead.
            print(f"[Image {i}] no bounding boxes in this result")
            continue

        print(f"[Image {i}] {len(boxes)} detection(s)")
        for box in boxes:
            cls_id = int(box.cls.item())
            cls_name = result.names[cls_id]
            conf_val = box.conf.item()
            # Each iterated box is treated as a single row; take row 0 as the
            # [x1, y1, x2, y2] coordinate list.
            xyxy = [round(v, 1) for v in box.xyxy[0].tolist()]
            print(f"  {cls_name:15s}  conf={conf_val:.2f}  bbox={xyxy}")

    return results


def compare_modalities(
    rgb_source: str,
    thermal_source: str,
    conf: float = 0.25,
    imgsz: int = 1024,
    iou: float = 0.45,
    device: str = "",
):
    """
    Compare RGB vs thermal detections on co-registered image pairs.

    Runs the ``matched_rgb`` model on ``rgb_source`` and the
    ``matched_thermal`` model on ``thermal_source`` and prints the detection
    count of each side, pair by pair. Useful for the matched dataset
    experiments.

    Pairing is positional: the i-th RGB result is compared with the i-th
    thermal result, so both sources must yield their images in the same
    order. NOTE(review): for two folders this presumably relies on matching
    file names sorting identically — confirm. Each pair's file names are
    printed so misalignment is visible, and a warning is printed when the two
    sources yield different numbers of results (extra results are skipped).

    Args:
        rgb_source: RGB image/folder passed to ``YOLO.predict``.
        thermal_source: Thermal image/folder passed to ``YOLO.predict``.
        conf: Confidence threshold for both models.
        imgsz: Inference image size for both models.
        iou: NMS IoU threshold for both models; defaults to the same 0.45
            used by ``run_inference`` so the two modes give comparable counts.
        device: Device string such as ``"0"`` or ``"cpu"``; empty lets
            Ultralytics choose.

    Returns:
        ``(rgb_results, thermal_results)`` as returned by ``YOLO.predict``.
    """
    rgb_model = load_model("matched_rgb")
    thermal_model = load_model("matched_thermal")

    predict_kwargs = {
        "imgsz": imgsz,
        "conf": conf,
        "iou": iou,
        "device": device or None,
        "verbose": False,
    }
    rgb_results = rgb_model.predict(rgb_source, **predict_kwargs)
    thermal_results = thermal_model.predict(thermal_source, **predict_kwargs)

    n_rgb, n_thm = len(rgb_results), len(thermal_results)
    if n_rgb != n_thm:
        # zip() below stops at the shorter side; make that visible.
        print(
            f"WARNING: {n_rgb} RGB result(s) vs {n_thm} thermal result(s); "
            f"only the first {min(n_rgb, n_thm)} pair(s) are compared."
        )

    for i, (r_rgb, r_thm) in enumerate(zip(rgb_results, thermal_results), start=1):
        # ``boxes`` is None for non-detection outputs; show "n/a" instead of crashing.
        count_rgb = "n/a" if r_rgb.boxes is None else len(r_rgb.boxes)
        count_thm = "n/a" if r_thm.boxes is None else len(r_thm.boxes)
        print(f"\n--- Pair {i} ---")
        print(f"  RGB     detections: {count_rgb}  ({Path(r_rgb.path).name})")
        print(f"  Thermal detections: {count_thm}  ({Path(r_thm.path).name})")

    return rgb_results, thermal_results


def main():
    """
    Command-line entry point.

    Parses the arguments below and either runs a single model over
    ``--source`` (``run_inference``) or, when ``--compare`` is given, runs the
    matched RGB and thermal models side by side (``compare_modalities``).
    Compare mode uses ``--conf``, ``--imgsz``, ``--iou`` and ``--device``.
    It ignores the model-selection and output options.
    """
    parser = argparse.ArgumentParser(description="Wildlife YOLOv26 Inference")
    parser.add_argument(
        "--model",
        default="rgb",
        choices=list(MODELS.keys()) + ["custom"],
        help="Model to use. Pass a file path with --model custom --weights <path>.",
    )
    parser.add_argument("--weights", default=None, help="Direct path to .pt weights file.")
    parser.add_argument("--source", default="0", help="Image/video/folder path or webcam index.")
    parser.add_argument("--imgsz", type=int, default=1024, help="Inference image size.")
    parser.add_argument("--conf", type=float, default=0.25, help="Confidence threshold.")
    parser.add_argument("--iou", type=float, default=0.45, help="NMS IoU threshold.")
    parser.add_argument("--show", action="store_true", help="Display results.")
    parser.add_argument("--save", action="store_true", help="Save annotated images.")
    parser.add_argument("--save-txt", action="store_true", help="Save YOLO-format labels.")
    parser.add_argument("--project", default="detections", help="Output project folder.")
    parser.add_argument("--name", default="predict", help="Output run name.")
    parser.add_argument("--device", default="", help="CUDA device, e.g. '0' or 'cpu'.")
    parser.add_argument(
        "--compare",
        nargs=2,
        metavar=("RGB_SOURCE", "THERMAL_SOURCE"),
        help="Compare RGB and thermal models on co-registered pairs.",
    )
    args = parser.parse_args()

    if args.compare:
        compare_modalities(
            rgb_source=args.compare[0],
            thermal_source=args.compare[1],
            conf=args.conf,
            imgsz=args.imgsz,
            iou=args.iou,
            device=args.device,
        )
        return

    if args.model == "custom":
        if not args.weights:
            # Without --weights the literal string "custom" would be handed to
            # YOLO as a weights path; fail early with a usage message instead.
            parser.error("--model custom requires --weights <path>")
        model_name = args.weights
    else:
        if args.weights:
            print(
                f"WARNING: --weights is only used with --model custom; "
                f"ignoring it and using '{args.model}'."
            )
        model_name = args.model

    run_inference(
        model_name=model_name,
        source=args.source,
        imgsz=args.imgsz,
        conf=args.conf,
        iou=args.iou,
        show=args.show,
        save=args.save,
        save_txt=args.save_txt,
        project=args.project,
        name=args.name,
        device=args.device,
    )


if __name__ == "__main__":
    main()