# codewithRiz's picture
# Update app.py
# 67ee0f6 verified
from ultralytics import YOLO
import cv2
import gradio as gr
import numpy as np
# -------------------------
# Load models once at import time so every call to predict() reuses them.
# -------------------------
# Detector that proposes a bounding box for each deer in the frame.
det_model = YOLO("models/detect/best_yolov8s.onnx")

# Crop classifiers, in load order: Buck vs Doe, then Mule vs Whitetail
# (predict() only consults the second one for crops classified as "buck").
buck_doe_model, mule_whitetail_model = (
    YOLO(onnx_path, task="classify")
    for onnx_path in (
        "models/classify/Buck_classification_epoch_26_best.onnx",
        "models/classify/mule_vs_whitetail.onnx",
    )
)
# -------------------------
# Inference function
# -------------------------
def _classify_deer(crop):
    """Return the display label for one detected deer crop.

    Runs the Buck/Doe classifier first; only crops whose top-1 class is
    "buck" (case-insensitive) are passed to the Mule/Whitetail classifier.

    Args:
        crop: BGR image region (numpy array) cut from the frame.

    Returns:
        "Deer | Doe" or "Deer | Buck | <mule/whitetail class name>".
    """
    buck_result = buck_doe_model(crop)[0]
    buck_name = buck_result.names[buck_result.probs.top1]
    if buck_name.lower() != "buck":
        return "Deer | Doe"

    # Second stage runs only for bucks.
    mule_result = mule_whitetail_model(crop)[0]
    mule_name = mule_result.names[mule_result.probs.top1]
    return f"Deer | Buck | {mule_name}"


def _draw_detection(canvas, box, label):
    """Draw one green bounding box and its label onto ``canvas`` in place.

    Args:
        canvas: BGR image (numpy array) to annotate; modified in place.
        box: (x1, y1, x2, y2) integer pixel coordinates, already clamped
            to the image bounds.
        label: Text drawn next to the box.
    """
    x1, y1, x2, y2 = box
    font, scale, thickness = cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2
    color = (0, 255, 0)  # pure green (same in BGR and RGB)

    cv2.rectangle(canvas, (x1, y1), (x2, y2), color, thickness)

    # putText's origin is the text's bottom-left corner. Prefer 10 px above
    # the box; if that would push the text past the top edge (box touching
    # the top of the frame), draw it just inside the box instead.
    (_, text_h), _ = cv2.getTextSize(label, font, scale, thickness)
    text_y = y1 - 10 if y1 - 10 >= text_h else y1 + text_h + 10
    cv2.putText(canvas, label, (x1, text_y), font, scale, color, thickness)


def predict(image):
    """Detect deer in an image and annotate each one with its classification.

    Args:
        image: RGB image as a numpy array, as delivered by
            ``gr.Image(type="numpy")`` (assumed 3-channel; TODO confirm if
            the input component's image_mode is ever changed), or ``None``
            when the user submits without uploading anything.

    Returns:
        The annotated RGB image, or ``None`` if no image was provided.
    """
    if image is None:
        return None

    # Gradio supplies RGB; convert to BGR, the channel order OpenCV drawing
    # and Ultralytics numpy inputs expect.
    frame = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    # Annotate a separate copy so that crops for later detections are cut
    # from clean pixels, not ones already painted with boxes and labels.
    canvas = frame.copy()
    height, width = frame.shape[:2]

    for result in det_model(frame):
        for box in result.boxes:
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            # Clamp to the frame: negative indices would wrap around in
            # numpy slicing and produce a wrong (or empty) crop.
            x1, x2 = max(0, x1), min(width, x2)
            y1, y2 = max(0, y1), min(height, y2)

            crop = frame[y1:y2, x1:x2]
            if crop.size == 0:
                continue

            _draw_detection(canvas, (x1, y1, x2, y2), _classify_deer(crop))

    # Convert back BGR -> RGB for display in Gradio.
    return cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)
# -------------------------
# Gradio UI: one image in, one annotated image out, wired to predict().
# -------------------------
app = gr.Interface(
    fn=predict,
    title="Buck Tracker AI – Deer Detection & Classification",
    description="\n".join(
        (
            "Upload a trail camera image. The system detects deer and classifies them as:",
            "- Deer | Doe",
            "- Deer | Buck | Mule or Whitetail",
        )
    ),
    inputs=gr.Image(label="Upload Deer Image", type="numpy"),
    outputs=gr.Image(label="Prediction", type="numpy"),
)

# -------------------------
# Launch only when executed as a script, not when imported.
# -------------------------
if __name__ == "__main__":
    app.launch()