| import os |
| import numpy as np |
| import gradio as gr |
| from transformers import ViTImageProcessor, ViTForImageClassification |
| from PIL import Image |
| import torch |
| from pathlib import Path |
|
|
| |
| |
| |
def _select_device() -> str:
    """Return the name of the best available torch device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return "cuda"
    # torch.backends.mps may be absent on older torch builds; treat that as "not available"
    # instead of failing at import time.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"


# Pretrained ViT checkpoint: the processor does resizing/normalisation, the model
# produces classification logits over the checkpoint's label set.
model_name = "google/vit-base-patch16-224"
processor = ViTImageProcessor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name)

device = _select_device()
model.to(device)
# Inference only: put layers such as dropout into evaluation behaviour.
model.eval()
|
|
def classify_image(image, top_k: int = 3):
    """Classify *image* with the ViT model and return its most likely labels.

    Args:
        image: The image to classify, as delivered by the Gradio ``Image``
            input (configured with ``type="pil"``), or ``None`` when the
            input is empty.
        top_k: How many of the highest-probability labels to return. It is
            clamped to the number of classes the model outputs.

    Returns:
        ``None`` when *image* is ``None``. Otherwise a dict that maps label
        name to softmax probability (Python ``float``), in the shape
        ``gr.Label`` accepts.
    """
    if image is None:
        return None

    inputs = processor(images=image, return_tensors="pt")
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs).logits

    # A single image was processed, so take batch row 0 and normalise over the classes.
    probs = torch.nn.functional.softmax(logits[0], dim=0)
    # torch.topk raises if k exceeds the number of elements, so clamp it.
    k = max(0, min(top_k, probs.numel()))
    top_probs, top_ids = torch.topk(probs, k)

    id2label = model.config.id2label
    return {
        id2label[class_id]: prob
        for prob, class_id in zip(top_probs.tolist(), top_ids.tolist())
    }
|
|
| |
| |
| |
SCRIPT_DIR = Path(__file__).resolve().parent

# Directories searched, in order, for example images; the first one that
# contains at least one PNG wins.
CANDIDATE_DIRS = [
    SCRIPT_DIR / "animal_images",
    SCRIPT_DIR,
]


def contains_png(directory: Path) -> bool:
    """Return True if *directory* is a directory directly containing a ``.png`` file.

    The suffix test ignores case, so ``.png``, ``.PNG`` and ``.Png`` all count.
    This matches the file types the example loader accepts, so a folder holding
    only upper-case ``.PNG`` files is not skipped. A missing path, or a path that
    is a plain file, returns False.
    """
    if not directory.is_dir():
        return False
    return any(
        entry.is_file() and entry.suffix.lower() == ".png"
        for entry in directory.iterdir()
    )


image_dir = next((d for d in CANDIDATE_DIRS if contains_png(d)), None)


print("[DEBUG] SCRIPT_DIR:", SCRIPT_DIR)
print("[DEBUG] Candidate dirs:", [str(d) for d in CANDIDATE_DIRS])
print("[DEBUG] Chosen image_dir:", image_dir)
|
|
def load_example(path: Path) -> np.ndarray:
    """Load the image at *path* and return it as an RGB ``numpy`` array.

    The image is converted to RGB (3 channels), whatever mode it was stored in.
    """
    # Image.open is lazy and keeps the file handle open. The context manager
    # closes it once the pixels are converted and copied into the array;
    # Pillow only auto-closes single-frame images after loading.
    with Image.open(path) as img:
        return np.array(img.convert("RGB"))
|
|
| |
# Gradio's ``examples`` takes a list of rows, one value per input component;
# here each row is a single pre-loaded image array.
examples = []
if image_dir is not None:
    # Scan the directory once and test the suffix case-insensitively. Globbing
    # "*.png" and "*.PNG" separately misses mixed-case suffixes and, where glob
    # matching is case-insensitive (e.g. Windows), returns every file twice.
    # is_file() also skips directories whose names happen to end in ".png".
    pngs = sorted(
        p for p in image_dir.iterdir()
        if p.is_file() and p.suffix.lower() == ".png"
    )
    print("[DEBUG] PNG files found:", [p.name for p in pngs])
    examples = [[load_example(p)] for p in pngs]
else:
    print("[DEBUG] No PNGs found in either animal_images/ or repo root.")
|
|
# Header text shown on the Gradio page.
APP_TITLE = "Animal Classifier"
APP_DESCRIPTION = (
    "Upload an image of an animal (or select an example) to classify it "
    "using the Google ViT model."
)

# Input component: the uploaded picture, passed to classify_image as a PIL image.
image_input = gr.Image(type="pil", label="Upload Image")
# Output component: renders classify_image's label -> probability dict (top 3 shown).
label_output = gr.Label(num_top_classes=3, label="Predictions")

demo = gr.Interface(
    fn=classify_image,
    inputs=image_input,
    outputs=label_output,
    examples=examples,
    # Do not pre-run the examples at startup.
    cache_examples=False,
    title=APP_TITLE,
    description=APP_DESCRIPTION,
)
|
|
if __name__ == "__main__":
    # Let Gradio serve files from the chosen example directory; when no such
    # directory was found, pass None.
    if image_dir is None:
        allowed_dirs = None
    else:
        allowed_dirs = [str(image_dir)]
    demo.launch(allowed_paths=allowed_dirs)
|
|