Spaces:
Running
Running
File size: 2,761 Bytes
412cf06 4fc5a19 ff45b18 f3969d0 ff45b18 f3969d0 fd25d05 f3969d0 fd25d05 f3969d0 fd25d05 f3969d0 412cf06 7cea9f7 412cf06 7cea9f7 412cf06 7cea9f7 f3969d0 f728991 4fc5a19 f728991 f3969d0 7cea9f7 f3969d0 f728991 7cea9f7 4fc5a19 7cea9f7 f3969d0 7cea9f7 f3969d0 7cea9f7 412cf06 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
import gradio as gr
import torch
from transformers import ViTForImageClassification
import os
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from torchvision import transforms
import json
# --- Model and label loading (runs once at import time) ---

# Download the fine-tuned weights from the model repo.
# HF_TOKEN is only needed if the repo is private; for a public repo the env
# var may be unset and token=None is accepted by hf_hub_download.
weights_path = hf_hub_download(
    repo_id="patcdaniel/UCSCPhytoViT83",
    filename="model.safetensors",
    token=os.environ.get("HF_TOKEN"),
)
state_dict = load_file(weights_path)

# Instantiate the ViT backbone with a classification head sized to match
# the fine-tuned checkpoint.
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=83,  # must match the number of classes used in training
)
model.load_state_dict(state_dict)

# Download the class-index -> label-name mapping from the same repo.
# (Separate variable: previously this reused `model_path`, shadowing the
# weights path above.)
labels_path = hf_hub_download(
    repo_id="patcdaniel/UCSCPhytoViT83",
    filename="label_names.json",
    token=os.environ.get("HF_TOKEN"),
    local_dir=".",
)

# Load class label dictionary; JSON keys are strings, so convert them to
# int to allow direct indexing by predicted class id.
with open(labels_path, "r") as f:
    id2label = {int(k): v for k, v in json.load(f).items()}

# Use GPU if available, and switch to inference mode: eval() disables
# dropout, which would otherwise make predictions non-deterministic.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
# Preprocessing pipeline — built once at module load rather than on every
# call to predict().
# NOTE(review): mean/std below are the standard ImageNet statistics; confirm
# they match the normalization used during fine-tuning.
_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),  # match ViT-base/16 input size
    transforms.ToTensor(),          # PIL.Image -> float tensor in [0, 1]
    transforms.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
])


def predict(image):
    """Classify a phytoplankton image and return the top-3 class scores.

    Args:
        image: PIL.Image from the Gradio input, or None if nothing was
            uploaded.

    Returns:
        dict mapping label name -> softmax probability (rounded to 4
        decimal places), suitable for gr.Label. On failure, a single
        {"Error": message} entry is returned instead.
    """
    if image is None:
        return {"Error": "No image provided"}
    try:
        # Force 3 channels: grayscale or RGBA uploads would otherwise fail
        # in the 3-channel Normalize step.
        pixel_values = _TRANSFORM(image.convert("RGB")).unsqueeze(0).to(device)
        with torch.no_grad():
            logits = model(pixel_values).logits
        probs = torch.nn.functional.softmax(logits, dim=1).squeeze()
        topk = torch.topk(probs, k=3)
        return {
            id2label[idx]: round(score, 4)
            for idx, score in zip(topk.indices.tolist(), topk.values.tolist())
        }
    except Exception as e:
        # Best-effort UI: log the full traceback server-side, surface a
        # short message to the user.
        import traceback
        print(traceback.format_exc())
        return {"Error": str(e)}
# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("# PhytoViT - IFCB Phytoplankton Classifier")
    gr.Markdown("Upload an image or paste a URL. Model: `phytoViT_508k_20250611`")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            # TODO(review): this textbox is never wired to any handler, so a
            # pasted URL is silently ignored — either connect it to predict()
            # or remove it (and the URL mention in the intro text).
            url_input = gr.Textbox(label="...or paste image URL")
            predict_btn = gr.Button("Classify")
        # predict() returns the top-3 classes, so label accordingly
        # (was mislabeled "Top 5 Predictions").
        label_output = gr.Label(label="Top 3 Predictions")
    predict_btn.click(fn=predict, inputs=image_input, outputs=label_output)

demo.launch()