Patrick Daniel
Fixed Transform
4fc5a19
import gradio as gr
import torch
from transformers import ViTForImageClassification
import os
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from torchvision import transforms
import json
# Hugging Face Hub repo holding the fine-tuned PhytoViT weights and labels.
HF_REPO_ID = "patcdaniel/UCSCPhytoViT83"
# Base architecture/config the fine-tuned weights were trained from.
BASE_CHECKPOINT = "google/vit-base-patch16-224-in21k"
# Must equal the number of classes the fine-tuned weights were trained on.
NUM_LABELS = 83

# Fetch the fine-tuned weights. The token comes from the HF_TOKEN environment
# variable (only required when the repo is private); if it is unset, None is
# passed and huggingface_hub falls back to its default credentials.
model_path = hf_hub_download(
    repo_id=HF_REPO_ID,
    filename="model.safetensors",
    token=os.environ.get("HF_TOKEN"),
)
state_dict = load_file(model_path)

# Build the ViT architecture with a NUM_LABELS-way classification head, then
# replace its parameters with the fine-tuned ones. load_state_dict is strict
# by default, so missing or unexpected keys raise instead of being ignored.
model = ViTForImageClassification.from_pretrained(
    BASE_CHECKPOINT,
    num_labels=NUM_LABELS,
)
model.load_state_dict(state_dict)
# Download the class-name table that ships alongside the weights.
# (Uses its own variable rather than reusing the weights path variable.)
labels_path = hf_hub_download(
    repo_id="patcdaniel/UCSCPhytoViT83",
    filename="label_names.json",
    token=os.environ.get("HF_TOKEN"),
    local_dir=".",
)
# Build the class-index -> class-name lookup used by predict. JSON object keys
# are always strings, so they are converted back to int to match the integer
# indices produced by torch.topk. Read as UTF-8 explicitly so class names
# decode the same way regardless of the host locale.
with open(labels_path, "r", encoding="utf-8") as f:
    id2label = {int(k): v for k, v in json.load(f).items()}
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Put the model in inference mode explicitly (disables dropout) rather than
# relying on from_pretrained having done so before load_state_dict ran.
model.eval()
# Inference function
def top_k_predictions(probs, labels, k=3):
    """Return the ``k`` most probable classes as ``{label: probability}``.

    Args:
        probs: Tensor of per-class probabilities indexed by class id. It is
            flattened first, so a ``(1, num_classes)`` batch of one works too.
        labels: Mapping from integer class id to display label.
        k: Maximum number of classes to return; clamped to the number of
            classes so a short probability vector does not make ``topk`` fail.

    Returns:
        Dict of label -> probability rounded to 4 decimals, inserted in
        descending order of probability.
    """
    flat = probs.reshape(-1)
    k = min(k, flat.numel())
    scores, indices = torch.topk(flat, k=k)
    return {
        labels[i]: round(score, 4)
        for i, score in zip(indices.tolist(), scores.tolist())
    }


def predict(image, top_k=3):
    """Classify a PIL image and return its ``top_k`` most likely classes.

    Args:
        image: ``PIL.Image`` from the Gradio upload widget, or ``None`` when
            nothing was uploaded. It is converted to RGB so grayscale or RGBA
            images match the 3-channel normalization below.
        top_k: Number of classes to report (default 3).

    Returns:
        Dict of label -> probability, suitable for ``gr.Label``.

    Raises:
        gr.Error: when no image is given or inference fails, so Gradio shows
            the message to the user instead of trying to render it as a label.
    """
    if image is None:
        raise gr.Error("No image provided. Please upload an image first.")
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # base checkpoint is a 224x224 ViT
        transforms.ToTensor(),  # PIL.Image -> float tensor in [0, 1], CHW
        # NOTE(review): these are ImageNet statistics; the stock processor for
        # the in21k base checkpoint uses mean=std=0.5. They must match what the
        # fine-tuning pipeline used — TODO confirm against training code.
        transforms.Normalize(mean=(0.485, 0.456, 0.406),
                             std=(0.229, 0.224, 0.225)),
    ])
    try:
        pixel_values = transform(image.convert("RGB")).unsqueeze(0).to(device)
        with torch.no_grad():
            logits = model(pixel_values).logits
        probs = torch.softmax(logits, dim=-1)
        return top_k_predictions(probs, id2label, k=top_k)
    except Exception as e:
        # Top-level UI boundary: log the full traceback server-side, then
        # surface a readable message in the app.
        import traceback
        print(traceback.format_exc())
        raise gr.Error(f"Classification failed: {e}") from e
# Gradio UI
# Upper bound on bytes downloaded for a pasted URL, so one request cannot
# exhaust server memory.
MAX_URL_IMAGE_BYTES = 10 * 1024 * 1024


def classify_url(url):
    """Download the image at ``url`` and classify it with ``predict``.

    Only http(s) URLs are accepted. NOTE(review): this fetches user-supplied
    URLs from the server (following redirects); if the app runs somewhere with
    access to internal services, consider restricting hosts.

    Raises:
        gr.Error: on a non-http(s) URL, a download failure or timeout, an
            oversized response, or bytes torchvision cannot decode as an image.
    """
    import urllib.request
    from torchvision.io import ImageReadMode, decode_image

    url = (url or "").strip()
    if not url.lower().startswith(("http://", "https://")):
        raise gr.Error("Please paste an http:// or https:// image URL.")
    request = urllib.request.Request(url, headers={"User-Agent": "PhytoViT-demo"})
    try:
        with urllib.request.urlopen(request, timeout=10) as response:
            data = response.read(MAX_URL_IMAGE_BYTES + 1)
    except OSError as e:  # URLError, HTTPError and timeouts derive from OSError
        raise gr.Error(f"Could not download image: {e}") from e
    if len(data) > MAX_URL_IMAGE_BYTES:
        raise gr.Error("Image is too large (limit 10 MB).")
    try:
        encoded = torch.frombuffer(bytearray(data), dtype=torch.uint8)
        decoded = decode_image(encoded, mode=ImageReadMode.RGB)
    except (RuntimeError, ValueError) as e:
        raise gr.Error(f"Could not decode image: {e}") from e
    return predict(transforms.functional.to_pil_image(decoded))


def classify(image, url):
    """Classify the uploaded image, falling back to the pasted URL if none."""
    if image is None and url and url.strip():
        return classify_url(url)
    return predict(image)


with gr.Blocks() as demo:
    gr.Markdown("# PhytoViT - IFCB Phytoplankton Classifier")
    gr.Markdown("Upload an image or paste a URL. Model: `phytoViT_508k_20250611`")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            url_input = gr.Textbox(label="...or paste image URL")
            predict_btn = gr.Button("Classify")
        # predict returns its top 3 classes by default; keep the header in sync.
        label_output = gr.Label(label="Top 3 Predictions")
    # The button prefers the uploaded image and falls back to the URL box;
    # pressing Enter in the URL box classifies the URL directly.
    predict_btn.click(fn=classify, inputs=[image_input, url_input], outputs=label_output)
    url_input.submit(fn=classify_url, inputs=url_input, outputs=label_output)

if __name__ == "__main__":
    demo.launch()