File size: 3,823 Bytes
424fe21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import html
import os

import gradio as gr
from PIL import Image
from transformers import pipeline

# Module-level image-classification pipeline using the pre-trained
# `google/vit-base-patch16-224` checkpoint. It is created once, when this
# module is imported, and reused for every classification request.
classifier = pipeline(
    task="image-classification",
    model="google/vit-base-patch16-224",
)

def classify_image(image):
    """
    Classify an uploaded image and render the top predictions as HTML progress bars.

    Args:
        image: The uploaded image as delivered by the Gradio ``gr.Image``
            component (configured with ``type="pil"``), or ``None`` when
            nothing has been uploaded.

    Returns:
        An HTML string. When ``image`` is ``None`` it is a short "Please upload
        an image." notice. Otherwise it lists at most five predictions from
        ``classifier``, each with its rank, label, confidence percentage and a
        coloured bar whose width equals that percentage.
    """
    if image is None:
        return "<div style='font-family: Arial, sans-serif; padding: 20px; text-align: center; color: #666;'>Please upload an image.</div>"

    results = classifier(image)

    # Collect fragments and join once instead of repeated string concatenation.
    parts = [
        "<div style='font-family: Arial, sans-serif;'>",
        "<h3 style='margin-top: 0;'>Top Predictions:</h3>",
    ]

    for rank, result in enumerate(results[:5], 1):
        # The label is inserted into HTML markup, so escape it: characters such
        # as '&' or '<' would otherwise be interpreted as markup.
        label = html.escape(str(result['label']))
        score = result['score'] * 100

        # Bar colour by confidence: green (>= 70%), orange (>= 40%), red otherwise.
        if score >= 70:
            bar_color = "#4CAF50"  # Green
        elif score >= 40:
            bar_color = "#FF9800"  # Orange
        else:
            bar_color = "#F44336"  # Red

        parts.append(f"""
        <div style='margin-bottom: 15px;'>
            <div style='display: flex; justify-content: space-between; margin-bottom: 5px;'>
                <span style='font-weight: bold;'>{rank}. {label}</span>
                <span style='font-weight: bold; color: #333;'>{score:.2f}%</span>
            </div>
            <div style='background-color: #e0e0e0; border-radius: 10px; height: 25px; overflow: hidden;'>
                <div style='background-color: {bar_color}; height: 100%; width: {score:.2f}%; transition: width 0.3s ease; border-radius: 10px;'></div>
            </div>
        </div>
        """)

    parts.append("</div>")
    return "".join(parts)

# Build the Gradio interface: image upload and action buttons on the left,
# the HTML prediction bars on the right, and clickable example images below.
with gr.Blocks(title="Animal Image Classifier") as demo:
    gr.Markdown("# Animal Image Classifier")
    gr.Markdown("Upload an animal photo to classify it using AI!")

    with gr.Row():
        with gr.Column():
            # type="pil" means classify_image receives a PIL image (or None).
            input_image = gr.Image(
                type="pil",
                label="Upload Animal Photo"
            )

            # Action buttons
            classify_btn = gr.Button("Classify Image", variant="primary", size="lg")
            clear_btn = gr.Button("Clear", variant="secondary")

        with gr.Column():
            # Displays the HTML produced by classify_image.
            output_html = gr.HTML(
                label="Classification Results"
            )

    # Example images are expected in the same directory as this script.
    example_images = [
        "cat.png",
        "frog.png",
        "hippo.png",
        "jaguar.png",
        "sloth.png",
        "toucan.png",
        "turtle.png"
    ]
    script_dir = os.path.dirname(os.path.abspath(__file__))
    candidate_paths = [os.path.join(script_dir, name) for name in example_images]

    # Offer only example files that actually exist on disk. gr.Examples
    # processes its example paths while the UI is being built, so a missing
    # file may prevent the app from starting. If no example is available, the
    # whole example section is omitted.
    example_paths = [[path] for path in candidate_paths if os.path.isfile(path)]

    if example_paths:
        gr.Markdown("### Example Images")
        gr.Markdown("Try these example images:")
        gr.Examples(
            examples=example_paths,
            inputs=input_image,
            label="Click on an example image to load it"
        )

    # Run the classifier on the current image and show the rendered results.
    classify_btn.click(
        fn=classify_image,
        inputs=input_image,
        outputs=output_html
    )

    # Reset both the uploaded image and the results panel.
    clear_btn.click(
        fn=lambda: (None, "<div></div>"),
        inputs=None,
        outputs=[input_image, output_html]
    )

# Launch the app only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()