"""Gradio demo for the UCSC PhytoViT phytoplankton image classifier.

At import time this script downloads fine-tuned ViT weights and a label map
from the Hugging Face Hub, loads them into a ``ViTForImageClassification``
model, builds a small Gradio UI and launches it.
"""

import json
import os
import traceback

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from torchvision import transforms
from transformers import ViTForImageClassification

REPO_ID = "patcdaniel/UCSCPhytoViT83"
BASE_MODEL = "google/vit-base-patch16-224-in21k"
NUM_LABELS = 83  # must match the number of classes used in training
TOP_K = 3  # number of predictions returned by predict() and named in the UI

# Only required when the model repo is private; None is passed when unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Download the fine-tuned weights and load them into the base architecture.
model_path = hf_hub_download(
    repo_id=REPO_ID,
    filename="model.safetensors",
    token=HF_TOKEN,
)
state_dict = load_file(model_path)
model = ViTForImageClassification.from_pretrained(
    BASE_MODEL,
    num_labels=NUM_LABELS,
)
model.load_state_dict(state_dict)

# Download the class-name mapping (stored next to the script via local_dir).
labels_path = hf_hub_download(
    repo_id=REPO_ID,
    filename="label_names.json",
    token=HF_TOKEN,
    local_dir=".",
)

# JSON object keys are always strings; convert them to int class indices so
# they can be looked up with the indices returned by torch.topk.
with open(labels_path, "r") as f:
    id2label = {int(k): v for k, v in json.load(f).items()}

# Use GPU if available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # inference only

# Built once instead of on every request.
# NOTE(review): these are the ImageNet mean/std values; they should match
# whatever normalisation was used during fine-tuning -- confirm.
_preprocess = transforms.Compose([
    transforms.Resize((224, 224)),  # match ViT input size
    transforms.ToTensor(),  # PIL.Image -> float tensor
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])


def predict(image):
    """Classify an image and return its top ``TOP_K`` classes.

    Args:
        image: PIL image supplied by the ``gr.Image`` input, or ``None`` when
            the user clicked the button without providing an image.

    Returns:
        Dict mapping class name to softmax probability (rounded to four
        decimals) for the ``TOP_K`` most probable classes.

    Raises:
        gr.Error: If no image was provided or inference failed. Gradio shows
            the message to the user; a dict with a string value would not be
            a valid ``gr.Label`` payload (label -> float confidence).
    """
    if image is None:
        raise gr.Error("Please upload an image before clicking Classify.")

    try:
        # Normalize() above uses three-channel statistics, so force RGB in
        # case a grayscale or RGBA image reaches this function.
        pixel_values = _preprocess(image.convert("RGB")).unsqueeze(0).to(device)

        with torch.no_grad():
            logits = model(pixel_values).logits

        probs = torch.nn.functional.softmax(logits, dim=1)[0]
        topk = torch.topk(probs, k=min(TOP_K, probs.numel()))

        return {
            id2label[index]: round(score, 4)
            for index, score in zip(topk.indices.tolist(), topk.values.tolist())
        }
    except Exception as exc:
        # Keep the full traceback in the server log, surface a short message
        # in the UI.
        print(traceback.format_exc())
        raise gr.Error(f"Prediction failed: {exc}") from exc


# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# PhytoViT - IFCB Phytoplankton Classifier")
    gr.Markdown("Upload an image or paste a URL. Model: `phytoViT_508k_20250611`")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            # NOTE(review): this textbox is not connected to any handler, so
            # a pasted URL is currently ignored.
            url_input = gr.Textbox(label="...or paste image URL")
            predict_btn = gr.Button("Classify")

        # Label text is derived from TOP_K so it cannot drift from predict().
        label_output = gr.Label(label=f"Top {TOP_K} Predictions")

    predict_btn.click(fn=predict, inputs=image_input, outputs=label_output)

demo.launch()