# French02947's picture
# Upload 9 files
# 169b2d6 verified
import os
import shutil
import tempfile
import zipfile
from functools import lru_cache
from glob import glob

import gradio as gr
import torch
from transformers import pipeline
# Zip archive of example images; a relative path, so it is resolved against the current working directory.
ZIP_PATH = "animal_images.zip"
# Directory the zip is extracted into and then searched (recursively) for example images.
EXAMPLES_DIR = "animal_images"
def ensure_examples_extracted(zip_path=None, examples_dir=None):
    """Extract the example-image zip into the examples directory if needed.

    Works both locally and in HF Spaces. Does nothing when the examples
    directory already exists and is non-empty, when something that is not a
    directory occupies that path, or when the zip file is absent.

    Extraction goes into a temporary staging directory next to the target and
    is moved into place only after it fully succeeds. A corrupt zip or an
    interrupted extraction therefore never leaves behind a half-populated (or
    empty) directory that would suppress every later attempt.

    Args:
        zip_path: Archive to extract. Defaults to ``ZIP_PATH``.
        examples_dir: Destination directory. Defaults to ``EXAMPLES_DIR``.

    Raises:
        zipfile.BadZipFile: if the archive exists but is not a valid zip.
        OSError: on filesystem errors while extracting or moving into place.
    """
    # Resolve defaults at call time so callers (and tests) can override them.
    zip_path = ZIP_PATH if zip_path is None else zip_path
    examples_dir = EXAMPLES_DIR if examples_dir is None else examples_dir

    if os.path.isdir(examples_dir):
        if os.listdir(examples_dir):
            return  # Already extracted (or populated by hand).
    elif os.path.exists(examples_dir):
        return  # A non-directory is in the way; leave it untouched.

    if not os.path.isfile(zip_path):
        return  # No archive shipped: examples are optional.

    # Stage in the same parent directory so the final os.replace is a rename
    # on the same filesystem.
    parent = os.path.dirname(os.path.abspath(examples_dir))
    staging = tempfile.mkdtemp(prefix=".extract-", dir=parent)
    try:
        with zipfile.ZipFile(zip_path, "r") as archive:
            archive.extractall(staging)
        if os.path.isdir(examples_dir):
            # Known to be empty from the check above; remove so the rename succeeds.
            os.rmdir(examples_dir)
        os.replace(staging, examples_dir)
    except BaseException:
        # Drop the partial extraction so the next call can retry cleanly.
        shutil.rmtree(staging, ignore_errors=True)
        raise
def get_example_image_paths(max_examples: int = 7):
    """Return up to ``max_examples`` image file paths for Gradio examples.

    Extracts the bundled example zip first if needed, then searches
    ``EXAMPLES_DIR`` recursively. Extension matching is case-insensitive, so
    files such as ``CAT.JPG`` are found on case-sensitive filesystems too.
    Hidden (dot-prefixed) files are skipped, as ``glob``'s ``*`` does not
    match them. Paths are sorted so the example order is stable across runs.

    Args:
        max_examples: Maximum number of paths to return; values <= 0 yield [].

    Returns:
        A sorted list of at most ``max_examples`` file paths.
    """
    ensure_examples_extracted()
    if max_examples <= 0:
        return []

    image_exts = {".png", ".jpg", ".jpeg", ".webp", ".bmp"}
    paths = [
        path
        for path in glob(os.path.join(EXAMPLES_DIR, "**", "*"), recursive=True)
        # Compare lower-cased extensions; require a regular file so a
        # directory named e.g. "x.png" is not offered as an image.
        if os.path.splitext(path)[1].lower() in image_exts and os.path.isfile(path)
    ]
    return sorted(paths)[:max_examples]
@lru_cache(maxsize=1)
def get_classifier():
    """Build the Hugging Face image-classification pipeline on first use.

    ``lru_cache`` on this zero-argument function memoizes the single result,
    so the model is loaded once and every later call reuses the same pipeline.
    Runs on the first CUDA device when one is available, otherwise on the CPU.
    """
    if torch.cuda.is_available():
        target_device = 0  # first CUDA GPU
    else:
        target_device = -1  # CPU, in the pipeline's device convention

    # Solid default ImageNet classifier (good for common animals)
    checkpoint = "google/vit-base-patch16-224"
    return pipeline(task="image-classification", model=checkpoint, device=target_device)
def classify_image(img):
    """Classify one image and return its top-5 predictions for ``gr.Label``.

    Args:
        img: PIL Image from ``gr.Image(type="pil")``. Gradio passes ``None``
            when the button is clicked before any image is uploaded.

    Returns:
        dict mapping predicted label -> confidence score (as a float), the
        shape ``gr.Label`` renders directly.

    Raises:
        gr.Error: when no image was provided, so the UI shows a readable
            message instead of a pipeline stack trace.
    """
    if img is None:
        raise gr.Error("Please upload an image (or pick an example) before clicking Classify.")

    preds = get_classifier()(img, top_k=5)
    # Convert the pipeline's list of {"label", "score"} dicts to {label: score}.
    return {p["label"]: float(p["score"]) for p in preds}
def build_demo():
    """Assemble the Gradio Blocks UI for the animal image classifier.

    The layout is an image input and a top-5 label output side by side, plus a
    "Classify" button wired to ``classify_image``. When example images are
    available, an examples gallery is added that fills the input on click.

    Returns:
        The ``gr.Blocks`` app (not yet launched).
    """
    example_paths = get_example_image_paths(7)
    examples = [[p] for p in example_paths]  # safer format for Gradio examples

    with gr.Blocks() as demo:
        gr.Markdown(
            "# Animal Image Classifier\n"
            "Upload an image (or click an example) to classify it with a pretrained Hugging Face vision model."
        )
        with gr.Row():
            inp = gr.Image(type="pil", label="Upload an animal photo")
            out = gr.Label(num_top_classes=5, label="Predictions (Top 5)")
        btn = gr.Button("Classify")
        btn.click(fn=classify_image, inputs=inp, outputs=out)

        # Only show the examples section when images were actually found
        # (e.g. the zip may be absent); otherwise the header sits over nothing.
        if examples:
            gr.Markdown("## Examples")
            gr.Examples(
                examples=examples,
                inputs=inp,
                label="Click an example image below"
            )
    return demo
# Script entry point: build the UI and start the Gradio server when this file is run directly.
if __name__ == "__main__":
    demo = build_demo()
    demo.launch()