Spaces:
Sleeping
Sleeping
| import cv2 | |
| import numpy as np | |
| import tensorflow as tf | |
| import gradio as gr | |
# -----------------------------
# Load TFLite model (exported with NMS=true)
# -----------------------------
MODEL_PATH = "best.tflite"

# The interpreter is built once at import time; detect() reuses this single
# instance for every request.
interpreter = tf.lite.Interpreter(model_path=MODEL_PATH)
interpreter.allocate_tensors()

input_details, output_details = (
    interpreter.get_input_details(),
    interpreter.get_output_details(),
)

# Side length used for resizing and box rescaling. Taken from axis 1 of the
# input shape -- assumes an NHWC [batch, height, width, channels] square
# input; TODO confirm for the exported model.
IMG_SIZE = input_details[0]["shape"][1]  # usually 640

# Index i of this list is the label for class id i in the model output.
CLASS_NAMES = ["buffalo", "elephant", "rhino", "zebra"]
# -----------------------------
# Preprocess image
# -----------------------------
def preprocess(image):
    """Prepare one image array as a single-item model input batch.

    The image is resized to IMG_SIZE x IMG_SIZE (aspect ratio is not
    preserved), cast to float32 and divided by 255 (mapping 8-bit pixel
    values into [0, 1]), then given a leading batch axis of length 1.
    """
    side = IMG_SIZE
    resized = cv2.resize(image, (side, side))
    scaled = resized.astype(np.float32) / 255.0
    return scaled[np.newaxis, ...]
# -----------------------------
# Draw bounding boxes
# -----------------------------
def draw_boxes(image, boxes, conf_thres=0.2, normalized=None):
    """Draw labelled detection boxes onto ``image`` in place and return it.

    Args:
        image: Image array of shape (H, W, C) to draw on; modified in place.
        boxes: Raw model output, reshaped here into rows of
            ``[x1, y1, x2, y2, score, class_id]`` (the NMS=true layout this
            file assumes -- TODO confirm against the exported model).
        conf_thres: Rows whose score is below this value are skipped. Rows
            with a score of 0 or less are always skipped, so a threshold of
            0 does not draw the (presumably zero-filled) padding rows of a
            fixed-size NMS output.
        normalized: ``True`` if coordinates are fractions of the input size
            ([0, 1]), ``False`` if they are model-input pixels
            (0..IMG_SIZE), ``None`` (default) to infer it from the data.

    Returns:
        The same ``image`` object with boxes and labels drawn.
    """
    rows = np.asarray(boxes).reshape(-1, 6)
    scores = rows[:, 4]
    # `scores > 0` keeps empty padding rows out even when conf_thres == 0
    # (the Gradio slider allows 0).
    rows = rows[(scores >= conf_thres) & (scores > 0)]
    if len(rows) == 0:
        return image

    h, w = image.shape[:2]
    if normalized is None:
        # NOTE(review): Ultralytics TFLite exports normalise box coordinates
        # to [0, 1] (its AutoBackend rescales them by image size); other
        # exporters use input pixels. A genuine pixel-space box never has
        # every coordinate <= 1.5, so that case is treated as normalised.
        normalized = bool(np.abs(rows[:, :4]).max() <= 1.5)
    scale_x = w if normalized else w / IMG_SIZE
    scale_y = h if normalized else h / IMG_SIZE

    def _clamp(value, upper):
        # Keep drawing coordinates inside the frame so edge boxes stay visible.
        return min(max(value, 0), upper)

    for x1, y1, x2, y2, score, cls in rows:
        cls = int(cls)
        x1 = _clamp(int(x1 * scale_x), w - 1)
        y1 = _clamp(int(y1 * scale_y), h - 1)
        x2 = _clamp(int(x2 * scale_x), w - 1)
        y2 = _clamp(int(y2 * scale_y), h - 1)

        # Guard against ids outside CLASS_NAMES (a negative id would
        # otherwise silently wrap to the wrong name).
        if 0 <= cls < len(CLASS_NAMES):
            name = CLASS_NAMES[cls]
        else:
            name = f"class {cls}"
        label = f"{name} {score:.2f}"

        cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
        # Put the label above the box, or just inside it when the box is
        # too close to the top edge for the text to fit.
        text_y = y1 - 10 if y1 >= 20 else y1 + 20
        cv2.putText(image, label, (x1, text_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
    return image
# -----------------------------
# Detection function
# -----------------------------
def detect(image, conf_thres=0.2):
    """Run the TFLite detector on one image and return an annotated copy.

    Args:
        image: PIL image from the Gradio input (a NumPy array also works),
            or ``None`` when the user submits without uploading anything.
        conf_thres: Minimum score for a detection to be drawn.

    Returns:
        A new NumPy image array with detections drawn, or ``None`` if no
        image was given.
    """
    if image is None:
        # Gradio passes None for an empty image input; np.array(None) would
        # produce a 0-d object array and crash inside preprocess().
        return None
    if hasattr(image, "convert"):
        # PIL image: normalise L / P / RGBA modes to 3-channel RGB so the
        # model sees the expected channel count and drawn boxes stay opaque.
        image = image.convert("RGB")

    # np.array copies, so this buffer never aliases the caller's image and
    # can be drawn on directly.
    img_np = np.array(image)
    if img_np.ndim == 2:
        img_np = np.stack([img_np] * 3, axis=-1)
    elif img_np.shape[-1] == 4:
        # Drop alpha; make contiguous because OpenCV cannot draw on a
        # strided view.
        img_np = np.ascontiguousarray(img_np[..., :3])

    # Run inference
    interpreter.set_tensor(input_details[0]["index"], preprocess(img_np))
    interpreter.invoke()
    boxes = interpreter.get_tensor(output_details[0]["index"])  # NMS=true output

    return draw_boxes(img_np, boxes, conf_thres)
# -----------------------------
# Gradio interface
# -----------------------------
# The two inputs map positionally onto detect(image, conf_thres).
_image_input = gr.Image(type="pil", label="Upload Image")
_threshold_input = gr.Slider(0, 1, value=0.2, step=0.05, label="Confidence Threshold")

demo = gr.Interface(
    fn=detect,
    inputs=[_image_input, _threshold_input],
    outputs=gr.Image(type="numpy", label="Detection Output"),
    title="🦁 African Wildlife Detection – YOLO TFLite (NMS=true)",
    description="Upload an image to detect buffalo, elephant, rhino, and zebra.",
)

# Listen on all interfaces on port 7860.
demo.launch(server_name="0.0.0.0", server_port=7860)