import os
import zipfile
from functools import lru_cache
from glob import glob

import gradio as gr
import torch
from transformers import pipeline

ZIP_PATH = "animal_images.zip"
EXAMPLES_DIR = "animal_images"

# File extensions (compared case-insensitively) accepted as example images.
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".webp", ".bmp")

# Number of predictions requested from the classifier and shown by gr.Label.
TOP_K = 5


def ensure_examples_extracted():
    """
    Extract ZIP_PATH into EXAMPLES_DIR if that directory does not exist yet.

    Does nothing when EXAMPLES_DIR already exists or when ZIP_PATH is
    missing. Both paths are relative to the current working directory.

    Raises:
        zipfile.BadZipFile: if ZIP_PATH exists but is not a valid zip archive.
    """
    if os.path.exists(EXAMPLES_DIR) or not os.path.exists(ZIP_PATH):
        return
    # Open the archive *before* creating anything on disk: a corrupt zip
    # then fails without leaving an empty EXAMPLES_DIR behind (which would
    # make every later start skip extraction). extractall() creates the
    # target directory itself.
    with zipfile.ZipFile(ZIP_PATH, "r") as archive:
        archive.extractall(EXAMPLES_DIR)


def get_example_image_paths(max_examples: int = 7) -> list[str]:
    """
    Return up to `max_examples` image file paths for Gradio examples.

    Extracts the bundled zip first if needed, then searches EXAMPLES_DIR
    recursively for files whose extension (case-insensitive) is in
    IMAGE_EXTENSIONS. Paths are sorted so the examples are stable between
    runs.

    Args:
        max_examples: Maximum number of paths to return; values <= 0 yield
            an empty list.

    Returns:
        A sorted list of at most `max_examples` file paths (possibly empty).
    """
    ensure_examples_extracted()
    # glob's "*" does not match names starting with ".", which keeps hidden
    # files (e.g. macOS "._name.jpg" metadata entries) out of the examples.
    candidates = glob(os.path.join(EXAMPLES_DIR, "**", "*"), recursive=True)
    paths = sorted(
        path
        for path in candidates
        if path.lower().endswith(IMAGE_EXTENSIONS) and os.path.isfile(path)
    )
    return paths[: max(max_examples, 0)]


@lru_cache(maxsize=1)
def get_classifier():
    """
    Load the HF image-classification pipeline once and reuse it.

    The cache on this zero-argument function means the model is loaded on
    the first call and the same pipeline object is returned afterwards.
    Uses the first CUDA device when available, otherwise the CPU (-1).
    """
    device = 0 if torch.cuda.is_available() else -1
    # Solid default ImageNet classifier (good for common animals).
    model_id = "google/vit-base-patch16-224"
    return pipeline(task="image-classification", model=model_id, device=device)


def classify_image(img):
    """
    Classify one image and return its top predictions.

    Args:
        img: PIL Image from gr.Image(type="pil"), or None when the user
            clicks Classify without providing an image.

    Returns:
        Dict mapping label -> confidence for the TOP_K best classes, the
        format gr.Label renders directly.

    Raises:
        gr.Error: if no image was provided, so the UI shows a readable
            message instead of a pipeline traceback.
    """
    if img is None:
        raise gr.Error("Please upload an image (or click an example) first.")
    preds = get_classifier()(img, top_k=TOP_K)
    # float() guarantees plain Python floats for gr.Label.
    return {p["label"]: float(p["score"]) for p in preds}


def build_demo():
    """
    Build the Gradio UI: image input, Classify button, top-k label output,
    and (only when example images are found) a clickable examples section.

    Returns:
        The gr.Blocks app, not yet launched.
    """
    # gr.Examples expects one inner list per example, holding one value per
    # input component (here just the image path).
    examples = [[path] for path in get_example_image_paths(7)]

    with gr.Blocks() as demo:
        gr.Markdown(
            "# Animal Image Classifier\n"
            "Upload an image (or click an example) to classify it with a pretrained Hugging Face vision model."
        )
        with gr.Row():
            inp = gr.Image(type="pil", label="Upload an animal photo")
            out = gr.Label(num_top_classes=TOP_K, label=f"Predictions (Top {TOP_K})")
        btn = gr.Button("Classify")
        btn.click(fn=classify_image, inputs=inp, outputs=out)

        # Skip the section when no example images exist (e.g. the zip is not
        # shipped) rather than rendering an empty "Examples" header.
        if examples:
            gr.Markdown("## Examples")
            gr.Examples(
                examples=examples,
                inputs=inp,
                label="Click an example image below",
            )
    return demo


if __name__ == "__main__":
    demo = build_demo()
    demo.launch()