# WK3 / app.py
# Provenance: uploaded by MiriamCamachodQ ("Upload 8 files", commit eb3cfa7 verified)
import gradio as gr
from transformers import pipeline
# Initialize the image classification pipeline with the Google ViT model.
# google/vit-base-patch16-224 is trained on ImageNet-1k, whose 1000 labels
# include many animal classes, so it works well for general animal classification.
# NOTE(review): pipeline() runs at import time, so the first start pays the
# model download/load cost before the app can serve requests.
classifier = pipeline("image-classification", model="google/vit-base-patch16-224")
def classify_image(image):
    """
    Classify *image* with the module-level ViT pipeline.

    Returns a ``{label: confidence}`` dict — the shape Gradio's Label
    component expects — or ``None`` when no image was supplied.
    """
    if image is None:
        return None
    # Fold the pipeline's list of {"label": ..., "score": ...} dicts
    # into a single label -> confidence mapping.
    scores = {}
    for prediction in classifier(image):
        scores[prediction["label"]] = prediction["score"]
    return scores
# Example images shipped with the Space. The doubled folder segment
# ('animal_images/animal_images/') matches the nested upload layout on disk.
_EXAMPLE_ANIMALS = ("cat", "frog", "hippo", "jaguar", "sloth", "toucan", "turtle")
examples = [
    [f"animal_images/animal_images/{animal}.png"] for animal in _EXAMPLE_ANIMALS
]
# Create the Gradio Interface wiring classify_image to an image input and a
# label output. The component labels ("inp", "output") deliberately match the
# names shown in the reference screenshot.
demo = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil", label="inp"),  # deliver the upload as a PIL image
    outputs=gr.Label(num_top_classes=3, label="output"),  # show the top-3 {label: score}
    examples=examples,  # clickable sample images defined above
    title="Animal Classifier",
    description="Upload an image of an animal to classify it using the Google ViT model."
)
if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, so importing
    # this module (e.g. by the Spaces runtime or tests) has no side effects.
    demo.launch()