File size: 2,761 Bytes
412cf06
 
4fc5a19
ff45b18
f3969d0
 
 
 
ff45b18
 
f3969d0
 
fd25d05
f3969d0
 
 
 
 
 
 
fd25d05
f3969d0
 
 
 
fd25d05
f3969d0
 
 
 
 
 
 
 
 
 
 
412cf06
7cea9f7
 
 
412cf06
7cea9f7
412cf06
7cea9f7
f3969d0
f728991
4fc5a19
f728991
 
f3969d0
 
 
7cea9f7
 
f3969d0
f728991
 
7cea9f7
4fc5a19
7cea9f7
 
 
f3969d0
7cea9f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f3969d0
7cea9f7
 
412cf06
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import gradio as gr
import torch
from transformers import ViTForImageClassification
import os
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from torchvision import transforms
import json


# --- Model setup (runs once at module import) ---

# Download the fine-tuned weights from the model repo.
# HF_TOKEN is only required for private repos; None is fine for public ones.
model_path = hf_hub_download(
    repo_id="patcdaniel/UCSCPhytoViT83",
    filename="model.safetensors",
    token=os.environ.get("HF_TOKEN")  # omit this line if public
)
state_dict = load_file(model_path)

# Base ViT architecture with a fresh classification head sized for this dataset.
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=83  # must match the number of classes used in training
)
model.load_state_dict(state_dict)
# BUGFIX: switch to inference mode — without this, dropout stays active and
# predictions are nondeterministic.
model.eval()

# Download the label-name mapping that accompanies the weights.
model_path = hf_hub_download(
    repo_id="patcdaniel/UCSCPhytoViT83",
    filename="label_names.json",
    token=os.environ.get("HF_TOKEN"),
    local_dir="."
)

# Build id -> label, coercing JSON string keys back to int class ids.
with open(model_path, "r") as f:
    id2label = {int(k): v for k, v in json.load(f).items()}

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Inference function
def predict(image):
    """Classify one phytoplankton image and return its top-3 predictions.

    Args:
        image: PIL.Image from the Gradio Image component. May be None if the
            user clicks Classify without uploading anything.

    Returns:
        dict mapping label name -> probability (rounded to 4 d.p.), the shape
        gr.Label expects; on failure, {"Error": message}.
    """
    try:
        # Guard against an empty submission instead of crashing in transform().
        if image is None:
            return {"Error": "Please upload an image first."}

        # BUGFIX: uploads may be RGBA or greyscale; Normalize below assumes
        # exactly 3 channels, so force RGB first.
        image = image.convert("RGB")

        transform = transforms.Compose([
            transforms.Resize((224, 224)),  # match ViT input size
            transforms.ToTensor(),          # PIL.Image -> float tensor in [0, 1]
            transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                 std=(0.229, 0.224, 0.225)),  # ImageNet stats
        ])

        # Add the batch dimension and move to the same device as the model.
        pixel_values = transform(image).unsqueeze(0).to(device)

        with torch.no_grad():
            logits = model(pixel_values).logits
            probs = torch.nn.functional.softmax(logits, dim=1).squeeze()

        topk = torch.topk(probs, k=3)
        top_indices = topk.indices.tolist()
        top_scores = topk.values.tolist()

        top_labels = [id2label[i] for i in top_indices]

        return {label: round(score, 4) for label, score in zip(top_labels, top_scores)}

    except Exception as e:
        # Log the full traceback server-side; return a friendly payload to the UI.
        import traceback
        print(traceback.format_exc())
        return {"Error": str(e)}

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# PhytoViT - IFCB Phytoplankton Classifier")
    gr.Markdown("Upload an image or paste a URL. Model: `phytoViT_508k_20250611`")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            # TODO(review): url_input is displayed but never wired to predict —
            # either hook it up (fetch URL -> PIL.Image) or drop it from the UI.
            url_input = gr.Textbox(label="...or paste image URL")
            predict_btn = gr.Button("Classify")
            # BUGFIX: label said "Top 5" but predict() returns the top 3.
            label_output = gr.Label(label="Top 3 Predictions")

    predict_btn.click(fn=predict, inputs=image_input, outputs=label_output)

demo.launch()