# Hugging Face Space: animal image classifier demo (Gradio + transformers).
| import os | |
| import zipfile | |
| from functools import lru_cache | |
| from glob import glob | |
| import gradio as gr | |
| import torch | |
| from transformers import pipeline | |
# Bundled archive of example images shipped alongside the app.
ZIP_PATH = "animal_images.zip"
# Directory the archive is extracted into; example images are read from here.
EXAMPLES_DIR = "animal_images"
def ensure_examples_extracted():
    """
    If animal_images.zip exists and animal_images/ does not, extract it.
    Works both locally and in HF Spaces.
    """
    # Already extracted (or shipped unzipped) — nothing to do.
    if os.path.exists(EXAMPLES_DIR):
        return
    # No archive present either: leave quietly; the app just has no examples.
    if not os.path.exists(ZIP_PATH):
        return
    os.makedirs(EXAMPLES_DIR, exist_ok=True)
    with zipfile.ZipFile(ZIP_PATH, "r") as archive:
        archive.extractall(EXAMPLES_DIR)
def get_example_image_paths(max_examples: int = 7):
    """
    Return up to `max_examples` image file paths for Gradio examples.

    Extracts the bundled zip first if needed. Extension matching is
    case-insensitive, so files like ``IMG_1234.JPG`` are found too
    (the original glob patterns silently skipped uppercase extensions).

    :param max_examples: maximum number of paths to return (default 7).
    :return: sorted list of image paths, truncated to `max_examples`.
    """
    ensure_examples_extracted()
    image_extensions = {".png", ".jpg", ".jpeg", ".webp", ".bmp"}
    # One recursive glob + extension filter: case-insensitive and
    # naturally duplicate-free (each file is visited exactly once).
    candidates = glob(os.path.join(EXAMPLES_DIR, "**", "*"), recursive=True)
    paths = [
        p for p in candidates
        if os.path.splitext(p)[1].lower() in image_extensions
    ]
    # Sort so the example list is stable/deterministic across runs.
    return sorted(paths)[:max_examples]
@lru_cache(maxsize=1)
def get_classifier():
    """
    Load the HF image-classification pipeline once and reuse it.

    Bug fix: the docstring promised "load once", but nothing cached the
    pipeline — classify_image() re-instantiated the model on every button
    click. `lru_cache` (already imported at the top of the file) now
    memoizes the zero-argument call so the model loads exactly once.
    """
    # Use the first GPU when available, otherwise CPU (-1 per HF convention).
    device = 0 if torch.cuda.is_available() else -1
    # Solid default ImageNet classifier (good for common animals)
    model_id = "google/vit-base-patch16-224"
    return pipeline(
        task="image-classification",
        model=model_id,
        device=device
    )
def classify_image(img):
    """
    Classify a PIL image (from gr.Image(type="pil")).

    Returns a {label: confidence} dict, the shape gr.Label renders nicely.
    """
    classifier = get_classifier()
    predictions = classifier(img, top_k=5)
    # Flatten the pipeline's list of {"label", "score"} dicts for gr.Label.
    scores = {}
    for prediction in predictions:
        scores[prediction["label"]] = float(prediction["score"])
    return scores
def build_demo():
    """
    Assemble the Gradio Blocks UI: title, upload widget, classify button,
    prediction label, and a clickable gallery of bundled example images.
    """
    example_paths = get_example_image_paths(7)

    with gr.Blocks() as demo:
        gr.Markdown(
            "# Animal Image Classifier\n"
            "Upload an image (or click an example) to classify it with a pretrained Hugging Face vision model."
        )

        with gr.Row():
            inp = gr.Image(type="pil", label="Upload an animal photo")
            out = gr.Label(num_top_classes=5, label="Predictions (Top 5)")

        classify_button = gr.Button("Classify")
        classify_button.click(fn=classify_image, inputs=inp, outputs=out)

        gr.Markdown("## Examples")
        gr.Examples(
            # One-item rows — the safer format for Gradio examples.
            examples=[[path] for path in example_paths],
            inputs=inp,
            label="Click an example image below",
        )

    return demo
if __name__ == "__main__":
    # Script entry point: build the UI and start the Gradio server.
    build_demo().launch()