File size: 2,377 Bytes
aac8470
 
 
 
 
8047dc3
 
395b902
8047dc3
aac8470
 
 
278cdaa
8047dc3
278cdaa
aac8470
 
278cdaa
 
aac8470
 
8047dc3
278cdaa
aac8470
 
f996c5a
278cdaa
 
 
aac8470
 
278cdaa
 
 
 
 
 
6614d3d
278cdaa
 
 
aac8470
278cdaa
 
 
aac8470
278cdaa
 
aac8470
278cdaa
 
 
aac8470
278cdaa
 
 
 
 
aac8470
278cdaa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""
Animal Type Classification App
A robust Gradio application for classifying animals using YOLOv8
"""

import gradio as gr
from ultralytics import YOLO
from PIL import Image
import numpy as np
import logging
import sys
import os
from typing import Optional

# ---------------------------------------------------------------------------
# Logging: send INFO-and-above records to stdout with a timestamped format.
# ---------------------------------------------------------------------------
logging.basicConfig(
    handlers=[logging.StreamHandler(stream=sys.stdout)],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Settings
# ---------------------------------------------------------------------------
# Path of the trained weights file, resolved against the working directory.
MODEL_PATH = "best_animal_classifier.pt"
# Labels the app advertises. NOTE(review): predictions are labelled from
# ``result.names`` at inference time, not from this list — confirm it is
# still needed.
CLASS_NAMES = ["butterfly", "chicken", "elephant", "horse", "spider", "squirrel"]

# ---------------------------------------------------------------------------
# Model: stays ``None`` when the weights are missing or fail to load, so the
# app can still start and report the problem to the user.
# ---------------------------------------------------------------------------
model = None
try:
    if not os.path.exists(MODEL_PATH):
        logger.error(f"❌ Model file not found at {MODEL_PATH}")
    else:
        model = YOLO(MODEL_PATH)
        logger.info("✅ Model loaded successfully!")
except Exception as e:
    logger.error(f"❌ Error loading model: {e}")
    model = None

def classify_animal(image: "Optional[Image.Image]") -> str:
    """Classify an uploaded image and describe the top prediction.

    Args:
        image: The image handed over by the Gradio input component, or
            ``None`` when nothing was uploaded.

    Returns:
        A human-readable message. On success it is the upper-cased top-1
        label followed by its confidence as a percentage. Otherwise it is
        an explanatory message: no image given, model unavailable, no
        classification probabilities in the result, or an inference error.
    """
    if image is None:
        return "Please upload an image."

    if model is None:
        return "Model not loaded. Check server logs."

    try:
        # The model returns one result per input image; only one was passed.
        results = model(image)
        if not results:
            return "No animals detected or classification failed."

        result = results[0]
        probs = result.probs
        if probs is None:
            # No classification probabilities were attached to the result.
            return "No animals detected or classification failed."

        # Index of the most probable class and its confidence score.
        top1_idx = probs.top1
        conf = probs.top1conf.item()
        label = result.names[top1_idx]
        return f"Prediction: {label.upper()} ({conf:.2%})"

    except Exception as e:
        # UI boundary: never let an inference failure crash the request.
        # logger.exception keeps the traceback in the server logs.
        logger.exception("Inference error: %s", e)
        return f"Error during classification: {str(e)}"

# Gradio Interface
# Offer the bundled sample image as a clickable example only when the file
# is actually present next to the app.
_EXAMPLE_IMAGE = "example_elephant.jpg"
if os.path.exists(_EXAMPLE_IMAGE):
    _examples = [[_EXAMPLE_IMAGE]]
else:
    _examples = None

demo = gr.Interface(
    title="🐾 Animal Type Classifier",
    description="Upload a photo of a butterfly, chicken, elephant, horse, spider, or squirrel to identify it.",
    fn=classify_animal,
    # type="pil" asks Gradio to pass the upload to the callback as a PIL image.
    inputs=gr.Image(type="pil", label="Upload Animal Image"),
    outputs=gr.Textbox(label="Result"),
    examples=_examples,
    cache_examples=False,
)

if __name__ == "__main__":
    # Listen on all interfaces on port 7860 when run as a script.
    demo.launch(server_port=7860, server_name="0.0.0.0")