File size: 2,794 Bytes
169b2d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import zipfile
from functools import lru_cache
from glob import glob

import gradio as gr
import torch
from transformers import pipeline


ZIP_PATH = "animal_images.zip"   # optional bundled archive of example images
EXAMPLES_DIR = "animal_images"   # directory the archive is extracted into


def ensure_examples_extracted():
    """Extract animal_images.zip into animal_images/ if not already done.

    Works both locally and in HF Spaces. No-op when the examples directory
    already exists or when the archive is missing. May raise
    zipfile.BadZipFile if the archive is present but corrupt.
    """
    # isdir (not exists): a stray *file* named "animal_images" must not be
    # mistaken for an already-extracted examples directory.
    if os.path.isdir(EXAMPLES_DIR):
        return

    # isfile (not exists): only try to open the archive if it really is a file.
    if os.path.isfile(ZIP_PATH):
        # Keep the explicit makedirs so the directory exists (and the early
        # return above fires next time) even when the archive is empty.
        os.makedirs(EXAMPLES_DIR, exist_ok=True)
        with zipfile.ZipFile(ZIP_PATH, "r") as z:
            z.extractall(EXAMPLES_DIR)


def get_example_image_paths(max_examples: int = 7):
    """Return up to `max_examples` image file paths for Gradio examples.

    Extensions are matched case-insensitively (so .JPG / .PNG files from
    cameras and phones are included), and the result is sorted so the
    example order is stable across runs.
    """
    ensure_examples_extracted()

    # Recognized image extensions, compared lowercased so the match is
    # case-insensitive — the original per-pattern globs missed .JPG etc.
    image_exts = {".png", ".jpg", ".jpeg", ".webp", ".bmp"}

    paths = [
        p
        for p in glob(os.path.join(EXAMPLES_DIR, "**", "*"), recursive=True)
        if os.path.isfile(p) and os.path.splitext(p)[1].lower() in image_exts
    ]

    # Keep it stable and limited to `max_examples`.
    return sorted(paths)[:max_examples]


@lru_cache(maxsize=1)
def get_classifier():
    """Build the HF image-classification pipeline on first use and reuse it."""
    # Solid default ImageNet classifier (good for common animals)
    model_name = "google/vit-base-patch16-224"

    # Run on the first CUDA GPU when one is available, otherwise on CPU.
    if torch.cuda.is_available():
        target_device = 0
    else:
        target_device = -1

    return pipeline(
        task="image-classification",
        model=model_name,
        device=target_device,
    )


def classify_image(img):
    """Classify a PIL image (as produced by gr.Image(type="pil")).

    Returns a {label: confidence} mapping, the shape gr.Label renders nicely.
    """
    predictions = get_classifier()(img, top_k=5)

    # Flatten the pipeline's list of {"label", "score"} dicts into one mapping.
    result = {}
    for prediction in predictions:
        result[prediction["label"]] = float(prediction["score"])
    return result


def build_demo():
    """Assemble and return the Gradio Blocks UI for the classifier."""
    # Wrap each path in a one-element list — safer format for Gradio examples.
    examples = [[path] for path in get_example_image_paths(7)]

    with gr.Blocks() as demo:
        gr.Markdown(
            "# Animal Image Classifier\n"
            "Upload an image (or click an example) to classify it with a pretrained Hugging Face vision model."
        )

        # Input on the left, top-5 predictions on the right.
        with gr.Row():
            image_input = gr.Image(type="pil", label="Upload an animal photo")
            label_output = gr.Label(num_top_classes=5, label="Predictions (Top 5)")

        classify_button = gr.Button("Classify")
        classify_button.click(fn=classify_image, inputs=image_input, outputs=label_output)

        gr.Markdown("## Examples")
        gr.Examples(
            examples=examples,
            inputs=image_input,
            label="Click an example image below",
        )

    return demo


if __name__ == "__main__":
    # Build the UI and start the local Gradio server.
    build_demo().launch()