# Animaldetection / app.py
# Author: DarshanM0di
# Last change: "Update app.py" (commit 8beae3b, verified)
import cv2
import numpy as np
import tensorflow as tf
import gradio as gr
# -----------------------------
# Load TFLite model (NMS=true)
# -----------------------------
# Path to the exported YOLO model; a relative path resolves against the
# process working directory.
MODEL_PATH = "best.tflite"

# Build the interpreter once at import time and reuse it for every request.
interpreter = tf.lite.Interpreter(model_path=MODEL_PATH)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Side length of the model input, read from dimension 1 of the input shape.
# NOTE(review): this assumes an NHWC layout with a square input — confirm if
# the model is re-exported with a different layout or a non-square size.
# The shape array holds numpy int32 values; cast to a plain Python int so
# cv2.resize's dsize tuple and the coordinate math get an ordinary integer.
IMG_SIZE = int(input_details[0]['shape'][1])  # usually 640

# Class labels, indexed by the integer class id the model emits.
CLASS_NAMES = ["buffalo", "elephant", "rhino", "zebra"]
# -----------------------------
# Preprocess image
# -----------------------------
def preprocess(image):
    """Build a single-image float32 batch for the model from an image array.

    The array is resized to IMG_SIZE x IMG_SIZE with a plain resize (no
    letterboxing, so aspect ratio is not preserved). It is then cast to float32
    and divided by 255, which maps 0-255 pixel values into [0, 1]. Finally a
    leading batch axis of length 1 is added.
    """
    resized = cv2.resize(image, (IMG_SIZE, IMG_SIZE))
    scaled = resized.astype(np.float32) / 255.0
    # Indexing with np.newaxis adds the batch dimension at position 0.
    return scaled[np.newaxis, ...]
# -----------------------------
# Draw bounding boxes
# -----------------------------
def draw_boxes(image, boxes, conf_thres=0.2):
    """Draw labelled detection boxes onto ``image`` in place and return it.

    Args:
        image: HxWxC image array to annotate; it is modified in place.
        boxes: raw NMS output of any shape whose size is a multiple of 6,
            read as rows of ``(x1, y1, x2, y2, score, class_id)``.
        conf_thres: rows with ``score < conf_thres`` are not drawn. Rows with
            a non-positive score are always skipped.

    Returns:
        The same ``image`` object, with rectangles and labels drawn on it.
    """
    h, w = image.shape[:2]
    dets = np.asarray(boxes, dtype=np.float32).reshape(-1, 6)

    # NMS-enabled exports presumably pad the output to a fixed row count with
    # all-zero rows. Skipping non-positive scores stops a threshold of 0 from
    # drawing every padding row as a "buffalo 0.00" box at the origin.
    scores = dets[:, 4]
    dets = dets[(scores >= conf_thres) & (scores > 0)]
    if len(dets) == 0:
        return image

    # Coordinates may arrive normalized to ~[0, 1] or in model-input pixels
    # (0..IMG_SIZE). NOTE(review): Ultralytics TFLite exports normalize them,
    # so confirm which form this model emits. A pixel-space box whose every
    # corner lies within 2 px of the origin is not a real detection, so
    # max <= 2 is treated as normalized output.
    scale = 1.0 if dets[:, :4].max() <= 2.0 else float(IMG_SIZE)
    sx, sy = w / scale, h / scale

    for x1, y1, x2, y2, score, cls_id in dets:
        # Map to original-image pixels, clamped so drawing stays on-canvas.
        x1 = min(max(int(x1 * sx), 0), w - 1)
        y1 = min(max(int(y1 * sy), 0), h - 1)
        x2 = min(max(int(x2 * sx), 0), w - 1)
        y2 = min(max(int(y2 * sy), 0), h - 1)

        cls = int(cls_id)
        # Fall back to the raw id rather than raising IndexError on an
        # unexpected class index.
        name = CLASS_NAMES[cls] if 0 <= cls < len(CLASS_NAMES) else f"class {cls}"
        label = f"{name} {score:.2f}"

        # Label goes above the box, or just inside it when the box touches
        # the top edge and the text would otherwise be drawn off-image.
        text_y = y1 - 10 if y1 >= 20 else y1 + 20
        cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(image, label, (x1, text_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
    return image
# -----------------------------
# Detection function
# -----------------------------
def detect(image, conf_thres=0.2):
    """Run the detector on one image and return an annotated copy.

    Args:
        image: PIL image from the Gradio input (a numpy array also works), or
            None when the user submits without uploading anything.
        conf_thres: minimum score for a detection to be drawn.

    Returns:
        The annotated image as a numpy array, or None when no image was given.
    """
    # Gradio passes None when the form is submitted with no image.
    if image is None:
        return None

    # preprocess feeds the array straight to the model, which presumably
    # expects 3 channels. Normalize any PIL mode (RGBA, L, P, ...) to RGB;
    # numpy arrays have no ``convert`` and are passed through unchanged.
    if hasattr(image, "convert"):
        image = image.convert("RGB")
    # np.array copies, so drawing on img_np never touches the caller's image.
    img_np = np.array(image)
    input_tensor = preprocess(img_np)

    # Run inference
    interpreter.set_tensor(input_details[0]['index'], input_tensor)
    interpreter.invoke()
    # With NMS baked into the export, output 0 holds the final detections.
    boxes = interpreter.get_tensor(output_details[0]['index'])

    return draw_boxes(img_np, boxes, conf_thres)
# -----------------------------
# Gradio interface
# -----------------------------
# Image input plus a threshold slider, mapped onto detect(image, conf_thres).
demo = gr.Interface(
    fn=detect,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Slider(0, 1, value=0.2, step=0.05, label="Confidence Threshold"),
    ],
    outputs=gr.Image(type="numpy", label="Detection Output"),
    title="🦁 African Wildlife Detection – YOLO TFLite (NMS=true)",
    description="Upload an image to detect buffalo, elephant, rhino, and zebra."
)

# Start the server only when executed as a script (e.g. `python app.py`), so
# importing this module does not block on a running web server.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)