# 601wk3HW / app.py
# Author: cespin24 — last commit 44052f4 "Update app.py" (verified)
import os
import numpy as np
import gradio as gr
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import torch
from pathlib import Path
# ----------------------------
# Model + processor
# ----------------------------
model_name = "google/vit-base-patch16-224"
processor = ViTImageProcessor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name)

# Pick the fastest available backend: a CUDA GPU (e.g. GPU-backed hosting),
# then Apple Silicon's MPS, and finally the CPU. Previously only MPS was
# checked, so CUDA hardware was silently ignored.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

model.to(device)
# Evaluation mode: turns off training-only behaviour such as dropout.
model.eval()
def classify_image(image):
    """Return the model's three most probable labels for *image*.

    Args:
        image: The uploaded picture as delivered by the ``gr.Image`` input
            (configured with ``type="pil"``), or ``None`` if nothing was
            provided.

    Returns:
        ``None`` for a missing image. Otherwise, a dict mapping label names
        (looked up in ``model.config.id2label``) to softmax probabilities,
        ordered from most to least likely, as consumed by ``gr.Label``.
    """
    if image is None:
        return None

    # Preprocess into model tensors and move each one onto the inference device.
    encoded = processor(images=image, return_tensors="pt")
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        scores = model(**encoded).logits

    # Only the first row of the batch is used (one image per call).
    distribution = scores[0].softmax(dim=0)
    best_probs, best_ids = distribution.topk(3)

    predictions = {}
    for prob, class_id in zip(best_probs.tolist(), best_ids.tolist()):
        predictions[model.config.id2label[class_id]] = prob
    return predictions
# ----------------------------
# Robust example discovery
# ----------------------------
SCRIPT_DIR = Path(__file__).resolve().parent
CANDIDATE_DIRS = [
    SCRIPT_DIR / "animal_images",  # preferred
    SCRIPT_DIR,  # fallback: repo root
]


def pick_image_dir(candidates):
    """Return the first directory in *candidates* that holds at least one PNG.

    The ``.png`` extension is matched case-insensitively, so folders that
    contain only ``*.PNG`` files are found too (a plain ``glob("*.png")`` is
    case-sensitive on POSIX). Missing paths and non-directories are skipped.

    Args:
        candidates: Iterable of ``Path`` objects, checked in order.

    Returns:
        The first qualifying ``Path``, or ``None`` if none qualifies.
    """
    for d in candidates:
        if d.is_dir() and any(
            p.is_file() and p.suffix.lower() == ".png" for p in d.iterdir()
        ):
            return d
    return None


# pick the first directory that contains at least one png
image_dir = pick_image_dir(CANDIDATE_DIRS)
print("[DEBUG] SCRIPT_DIR:", SCRIPT_DIR)
print("[DEBUG] Candidate dirs:", [str(d) for d in CANDIDATE_DIRS])
print("[DEBUG] Chosen image_dir:", image_dir)
def load_example(path: Path) -> np.ndarray:
    """Load the image at *path* as an RGB NumPy array for use as an example.

    The image is opened in a ``with`` block so the underlying file handle is
    closed as soon as the pixels have been read. ``Image.open`` is lazy and
    otherwise keeps the file open until garbage collection.

    Args:
        path: Filesystem path to an image file.

    Returns:
        The image converted to RGB, as the array produced by ``np.array``.
    """
    with Image.open(path) as img:
        return np.array(img.convert("RGB"))
# If we found a directory, load all PNGs in it as examples
examples = []
if image_dir is not None:
    # One directory scan with a case-insensitive suffix check. Globbing
    # "*.png" and "*.PNG" separately returns every file twice on
    # case-insensitive platforms (e.g. Windows), which duplicated examples,
    # and it missed mixed-case extensions such as ".Png".
    pngs = sorted(
        p for p in image_dir.iterdir() if p.is_file() and p.suffix.lower() == ".png"
    )
    print("[DEBUG] PNG files found:", [p.name for p in pngs])
    examples = [[load_example(p)] for p in pngs]
else:
    print("[DEBUG] No PNGs found in either animal_images/ or repo root.")
# Input/output components for the interface: a PIL-typed image upload and a
# label widget showing the three best predictions.
image_input = gr.Image(type="pil", label="Upload Image")
label_output = gr.Label(num_top_classes=3, label="Predictions")

demo = gr.Interface(
    fn=classify_image,
    inputs=image_input,
    outputs=label_output,
    examples=examples,
    cache_examples=False,
    title="Animal Classifier",
    description="Upload an image of an animal (or select an example) to classify it using the Google ViT model.",
)
if __name__ == "__main__":
    if image_dir is None:
        allowed = None
    else:
        # Let Gradio serve files from the discovered example folder (safe for HF).
        allowed = [str(image_dir)]
    demo.launch(allowed_paths=allowed)