Patrick Daniel
Initial commit
412cf06
raw
history blame
1.49 kB
import gradio as gr
import torch
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor
import requests
import json
# Hugging Face Hub repository holding the fine-tuned ViT weights, its image
# processor config, and the hosted class-label file.
MODEL_ID = "patcdaniel/phytoViT_508k_20250611"

# Load model and processor from Hugging Face Hub
model = ViTForImageClassification.from_pretrained(MODEL_ID)
processor = ViTImageProcessor.from_pretrained(MODEL_ID)
model.eval()  # inference mode (affects layers such as dropout)

# Load class labels from hosted file. predict() indexes this with integer
# class ids, so it is expected to be a JSON list ordered by class id.
LABELS_URL = f"https://huggingface.co/{MODEL_ID}/resolve/main/label_names.json"
# Bound the wait so app startup cannot hang forever on a stalled connection,
# and fail loudly on an HTTP error instead of parsing an error body as labels.
_labels_response = requests.get(LABELS_URL, timeout=30)
_labels_response.raise_for_status()
class_labels = _labels_response.json()
def predict(image, k=2):
    """Classify a phytoplankton image and return its top-``k`` labels.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without providing an image.
        k: Number of top predictions to return (default 2). Clamped to the
            number of classes the model outputs.

    Returns:
        dict mapping label name to softmax probability (rounded to 4 decimal
        places), ordered from most to least likely; an empty dict when
        ``image`` is ``None``.
    """
    if image is None:
        # Nothing to classify; an empty mapping clears the Label output
        # instead of crashing on ``None.convert``.
        return {}
    image = image.convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    # One image was processed, so logits is presumably (1, num_classes).
    # Take row 0 rather than squeeze() so a 1-class model still yields a
    # 1-D tensor that topk can handle.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
    # torch.topk raises if k exceeds the number of elements, so clamp it.
    topk = torch.topk(probs, k=min(k, probs.numel()))
    top_scores = topk.values.tolist()
    top_labels = [class_labels[i] for i in topk.indices.tolist()]
    # Format output: label -> rounded probability, highest first.
    return {label: round(score, 4) for label, score in zip(top_labels, top_scores)}
# Gradio interface: one PIL image in, a label/confidence mapping out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload or Paste an Image"),
    # Shows at most 2 classes, matching the number of predictions predict()
    # returns by default; keep the two in sync.
    outputs=gr.Label(num_top_classes=2, label="Top Predictions"),
    title="PhytoViT Classifier",
    # NOTE(review): gr.Image takes uploaded or clipboard-pasted images;
    # confirm that pasting an image *URL* really works before advertising it.
    description="Upload an IFCB phytoplankton image or paste an image URL to classify it using a ViT model trained on 508k examples.",
)

# Only start the server when run as a script, not when the module is imported.
if __name__ == "__main__":
    demo.launch()