"""Gradio demo serving a YOLO TFLite detector (exported with NMS) for African wildlife."""

import cv2
import numpy as np
import tensorflow as tf
import gradio as gr

# -----------------------------
# Load TFLite model (NMS=true)
# -----------------------------
MODEL_PATH = "best.tflite"
interpreter = tf.lite.Interpreter(model_path=MODEL_PATH)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# The input shape is read as [batch, height, width, channels].
# NOTE(review): assumes an NHWC model, as the original code already did.
IMG_SIZE = int(input_details[0]["shape"][1])  # usually 640
IMG_H = IMG_SIZE
IMG_W = int(input_details[0]["shape"][2])
CLASS_NAMES = ["buffalo", "elephant", "rhino", "zebra"]

# Each post-NMS detection row is expected to be [x1, y1, x2, y2, score, class_id].
_DET_FIELDS = 6


# -----------------------------
# Preprocess image
# -----------------------------
def preprocess(image):
    """Resize an RGB image to the model input size and add a batch axis.

    Float models receive pixel values scaled to [0, 1] as float32.
    Integer (fully-quantized) models receive those [0, 1] values mapped
    through the input tensor's (scale, zero_point) parameters, or the raw
    0-255 pixel values when the tensor carries no quantization parameters.

    Args:
        image: H x W x 3 pixel array.

    Returns:
        Array of shape (1, IMG_H, IMG_W, 3) in the dtype the model input expects.
    """
    resized = cv2.resize(image, (IMG_W, IMG_H))  # cv2 takes (width, height)
    in_dtype = input_details[0]["dtype"]
    if np.issubdtype(in_dtype, np.integer):
        scale, zero_point = input_details[0]["quantization"]
        values = resized.astype(np.float32)
        if scale:
            # TFLite affine quantization: real = (q - zero_point) * scale.
            values = values / 255.0 / scale + zero_point
        limits = np.iinfo(in_dtype)
        img = np.clip(np.round(values), limits.min, limits.max).astype(in_dtype)
    else:
        img = resized.astype(np.float32) / 255.0
    return np.expand_dims(img, axis=0)


def _dequantize_output(raw, detail):
    """Return *raw* as float32, undoing TFLite integer quantization if present."""
    values = raw.astype(np.float32)
    scale, zero_point = detail["quantization"]
    if np.issubdtype(raw.dtype, np.integer) and scale:
        values = (values - zero_point) * scale
    return values


# -----------------------------
# Draw bounding boxes
# -----------------------------
def draw_boxes(image, boxes, conf_thres=0.2, normalized=None):
    """Draw post-NMS detections on *image* in place and return it.

    Args:
        image: H x W x 3 uint8 array (the original, un-resized frame).
        boxes: model output whose last axis holds the 6 fields
            [x1, y1, x2, y2, score, class_id].
        conf_thres: detections with a score below this value are skipped.
        normalized: True if box coordinates are in [0, 1], or False if they
            are in model-input pixels. None auto-detects: if every kept
            coordinate lies within [-1.5, 1.5], the boxes are treated as
            normalized.
            NOTE(review): TFLite YOLO exports often emit normalized coordinates;
            confirm against the exported model and pass this explicitly if needed.

    Returns:
        The same *image* array, with rectangles and labels drawn on it.

    Raises:
        ValueError: if the last axis of *boxes* does not have 6 fields
            (e.g. the model was exported without NMS).
    """
    dets = np.asarray(boxes, dtype=np.float32)
    if dets.ndim == 0 or dets.shape[-1] != _DET_FIELDS:
        raise ValueError(
            f"Expected detections with last axis of size {_DET_FIELDS} "
            f"([x1, y1, x2, y2, score, class]), got shape {dets.shape}. "
            "Was the model exported with NMS?"
        )
    dets = dets.reshape(-1, _DET_FIELDS)
    # Drop padded/invalid rows (NaN/inf) and low-confidence detections.
    dets = dets[np.isfinite(dets).all(axis=1)]
    dets = dets[dets[:, 4] >= conf_thres]
    if dets.size == 0:
        return image

    h, w = image.shape[:2]
    if normalized is None:
        normalized = float(np.abs(dets[:, :4]).max()) <= 1.5
    # Scale boxes back to the original frame, which the model saw as a plain
    # (non-letterboxed) resize to IMG_W x IMG_H.
    if normalized:
        sx, sy = w, h
    else:
        sx, sy = w / IMG_W, h / IMG_H

    for x1, y1, x2, y2, score, cls_id in dets:
        cls_id = int(cls_id)
        left = int(np.clip(x1 * sx, 0, w - 1))
        top = int(np.clip(y1 * sy, 0, h - 1))
        right = int(np.clip(x2 * sx, 0, w - 1))
        bottom = int(np.clip(y2 * sy, 0, h - 1))

        # Guard against unknown class ids instead of raising IndexError or
        # silently wrapping a negative index to the wrong class name.
        if 0 <= cls_id < len(CLASS_NAMES):
            name = CLASS_NAMES[cls_id]
        else:
            name = f"class {cls_id}"
        label = f"{name} {score:.2f}"

        cv2.rectangle(image, (left, top), (right, bottom), (0, 255, 0), 2)
        # Keep the label visible when the box touches the top of the frame.
        text_y = top - 10 if top >= 20 else top + 20
        cv2.putText(image, label, (left, text_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
    return image


# -----------------------------
# Detection function
# -----------------------------
def detect(image, conf_thres=0.2):
    """Run the detector on an uploaded image and return an annotated RGB array.

    Args:
        image: PIL image from Gradio (or an H x W [x C] array), or None when
            nothing was uploaded.
        conf_thres: minimum score for a detection to be drawn.

    Returns:
        Annotated H x W x 3 array, or None when *image* is None.
    """
    if image is None:
        return None
    # Force 3-channel RGB. RGBA PNGs, grayscale and palette images would
    # otherwise produce 4- or 1-channel arrays the model cannot consume.
    if hasattr(image, "convert"):
        image = image.convert("RGB")
    img_np = np.array(image)
    if img_np.ndim == 2:
        img_np = cv2.cvtColor(img_np, cv2.COLOR_GRAY2RGB)
    elif img_np.shape[2] == 4:
        img_np = cv2.cvtColor(img_np, cv2.COLOR_RGBA2RGB)

    input_tensor = preprocess(img_np)

    # Run inference
    interpreter.set_tensor(input_details[0]["index"], input_tensor)
    interpreter.invoke()
    raw = interpreter.get_tensor(output_details[0]["index"])  # NMS=true
    boxes = _dequantize_output(raw, output_details[0])

    # Draw on a copy so the converted input array is left untouched.
    return draw_boxes(img_np.copy(), boxes, conf_thres)


# -----------------------------
# Gradio interface
# -----------------------------
demo = gr.Interface(
    fn=detect,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Slider(0, 1, value=0.2, step=0.05, label="Confidence Threshold"),
    ],
    outputs=gr.Image(type="numpy", label="Detection Output"),
    title="🦁 African Wildlife Detection – YOLO TFLite (NMS=true)",
    description="Upload an image to detect buffalo, elephant, rhino, and zebra.",
)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)