"""Gradio demo: detect deer with YOLO and classify each detection.

For every box returned by the detector, the crop is classified as Buck/Doe;
bucks are additionally classified as Mule/Whitetail. Each box is drawn on the
output image with a label of the form "Deer | Doe" or "Deer | Buck | <name>".
"""

from ultralytics import YOLO
import cv2
import gradio as gr
import numpy as np

# -------------------------
# Load models (once)
# -------------------------
det_model = YOLO("models/detect/best_yolov8s.onnx")
buck_doe_model = YOLO(
    "models/classify/Buck_classification_epoch_26_best.onnx", task="classify"
)
mule_whitetail_model = YOLO(
    "models/classify/mule_vs_whitetail.onnx", task="classify"
)

# Drawing style (colors are BGR because drawing happens in OpenCV space).
BOX_COLOR = (0, 255, 0)
LINE_THICKNESS = 2
LABEL_FONT = cv2.FONT_HERSHEY_SIMPLEX
LABEL_SCALE = 0.7
LABEL_OFFSET = 10  # pixels between the box's top edge and the label baseline


# -------------------------
# Helpers
# -------------------------
def _top1_name(model, crop):
    """Run classification *model* on *crop* and return its top-1 class name."""
    result = model(crop)[0]
    return result.names[result.probs.top1]


def _clip_box(xyxy, width, height):
    """Convert a box's (x1, y1, x2, y2) to ints clipped to the image bounds.

    Defensive: negative indices would otherwise wrap around in numpy slicing.
    """
    x1, y1, x2, y2 = map(int, xyxy)
    x1 = min(max(x1, 0), width)
    x2 = min(max(x2, 0), width)
    y1 = min(max(y1, 0), height)
    y2 = min(max(y2, 0), height)
    return x1, y1, x2, y2


def _label_for(crop):
    """Build the two-stage label for one deer crop."""
    # First stage: Buck / Doe.
    buck_name = _top1_name(buck_doe_model, crop)
    if buck_name.lower() == "buck":
        # Second stage (bucks only): Mule / Whitetail.
        mule_name = _top1_name(mule_whitetail_model, crop)
        return f"Deer | Buck | {mule_name}"
    return "Deer | Doe"


def _draw_detection(canvas, box, label):
    """Draw *box* and *label* onto *canvas* in place, keeping text on-canvas."""
    x1, y1, x2, y2 = box
    cv2.rectangle(canvas, (x1, y1), (x2, y2), BOX_COLOR, LINE_THICKNESS)
    (_, text_height), _ = cv2.getTextSize(
        label, LABEL_FONT, LABEL_SCALE, LINE_THICKNESS
    )
    # Boxes touching the top edge would otherwise push the label off-image.
    text_y = max(y1 - LABEL_OFFSET, text_height)
    cv2.putText(
        canvas, label, (x1, text_y), LABEL_FONT, LABEL_SCALE,
        BOX_COLOR, LINE_THICKNESS,
    )


# -------------------------
# Inference function
# -------------------------
def predict(image):
    """Detect deer in *image* and return an annotated copy.

    Args:
        image: RGB numpy array as supplied by ``gr.Image(type="numpy")``,
            or ``None`` when the user submitted without an image.

    Returns:
        The annotated RGB numpy array, or ``None`` if *image* is ``None``.
    """
    if image is None:
        return None

    # Convert RGB (Gradio) -> BGR (OpenCV / ultralytics numpy convention).
    clean = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    # Draw on a separate canvas so annotations never leak into later crops
    # (crops are views into `clean`, which must stay untouched).
    canvas = clean.copy()
    height, width = clean.shape[:2]

    for r in det_model(clean):
        for box in r.boxes:
            x1, y1, x2, y2 = _clip_box(box.xyxy[0], width, height)
            crop = clean[y1:y2, x1:x2]
            if crop.size == 0:
                continue
            _draw_detection(canvas, (x1, y1, x2, y2), _label_for(crop))

    # Convert back BGR -> RGB for display.
    return cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)


# -------------------------
# Gradio UI
# -------------------------
app = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy", label="Upload Deer Image"),
    outputs=gr.Image(type="numpy", label="Prediction"),
    title="Buck Tracker AI – Deer Detection & Classification",
    description=(
        "Upload a trail camera image.\n"
        "The system detects deer and classifies them as:\n"
        "- Deer | Doe\n"
        "- Deer | Buck | Mule or Whitetail"
    ),
)

# -------------------------
# Launch
# -------------------------
if __name__ == "__main__":
    app.launch()