Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| from transformers import ViTForImageClassification | |
| import os | |
| from safetensors.torch import load_file | |
| from huggingface_hub import hf_hub_download | |
| from torchvision import transforms | |
| import json | |
# Hugging Face repos for the fine-tuned classifier and the base architecture.
# HF_TOKEN is optional; it is only required when MODEL_REPO is private.
MODEL_REPO = "patcdaniel/UCSCPhytoViT83"
BASE_MODEL = "google/vit-base-patch16-224-in21k"
HF_TOKEN = os.environ.get("HF_TOKEN")

# Load the class-index -> label mapping first so the classifier head can be
# sized from it rather than from a hard-coded count. JSON object keys are
# always strings, so they are converted back to int indices.
labels_path = hf_hub_download(
    repo_id=MODEL_REPO,
    filename="label_names.json",
    token=HF_TOKEN,
    local_dir=".",
)
with open(labels_path, "r", encoding="utf-8") as f:
    id2label = {int(k): v for k, v in json.load(f).items()}

# Download the fine-tuned weights.
model_path = hf_hub_download(
    repo_id=MODEL_REPO,
    filename="model.safetensors",
    token=HF_TOKEN,
)
state_dict = load_file(model_path)

# Build the architecture from the base model's *config* only. Every weight is
# replaced by load_state_dict below, so downloading the base checkpoint's
# weights (as ViTForImageClassification.from_pretrained would) is wasted work.
# The class names are attached to the config so model.config.id2label is
# meaningful instead of generic LABEL_i names.
config = ViTForImageClassification.config_class.from_pretrained(
    BASE_MODEL,
    num_labels=len(id2label),
    id2label=id2label,
    label2id={label: idx for idx, label in id2label.items()},
)
model = ViTForImageClassification(config)
# Strict loading (the default) raises if the checkpoint and the architecture
# disagree, e.g. when label_names.json has a different class count than the
# trained classifier head.
model.load_state_dict(state_dict)

# Use GPU if available. A model built from a config starts in training mode,
# so switch to eval() explicitly for inference.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
# Inference function
# Preprocessing applied to every request, built once at import time instead of
# per call. Resizes to the 224x224 input the ViT expects and normalizes with
# the ImageNet channel mean/std (presumably what the model was fine-tuned
# with; confirm against the training pipeline).
_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),  # PIL.Image -> float tensor (C, H, W) in [0, 1]
    transforms.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
])


def predict(image, top_k=5):
    """Classify *image* and return its highest-probability classes.

    Args:
        image: A PIL image (as produced by ``gr.Image(type="pil")``), or None
            when nothing was provided.
        top_k: Number of classes to return. It is clamped to the range
            [1, number of classes]. Defaults to 5 to match the UI's
            "Top 5 Predictions" output.

    Returns:
        A ``{label: probability}`` dict with probabilities rounded to 4
        decimal places. If the image is missing or inference fails, it
        returns ``{"Error": message}`` instead.
    """
    if image is None:
        return {"Error": "No image provided"}
    try:
        # Normalize above expects exactly 3 channels; grayscale ("L") or RGBA
        # inputs would otherwise fail, so coerce to RGB first.
        pixel_values = _TRANSFORM(image.convert("RGB")).unsqueeze(0).to(device)
        with torch.no_grad():
            logits = model(pixel_values).logits
        # Batch of one: take row 0 rather than squeeze() so the result is
        # always 1-D.
        probs = torch.softmax(logits, dim=-1)[0]
        k = max(1, min(top_k, probs.numel()))
        scores, indices = torch.topk(probs, k=k)
        return {
            id2label[idx]: round(score, 4)
            for idx, score in zip(indices.tolist(), scores.tolist())
        }
    except Exception as e:
        # Top-level UI boundary: log the full traceback to the Space logs and
        # report the message to the caller instead of crashing the handler.
        import traceback
        traceback.print_exc()
        return {"Error": str(e)}
# Gradio UI
def _load_image_from_url(url, timeout=10, max_bytes=20 * 1024 * 1024):
    """Fetch an image over http(s) and return it as an RGB PIL image.

    The URL comes from an untrusted user, so several limits apply:
    - only http/https schemes are accepted (no file:// or similar);
    - the request is time-limited;
    - the download is capped at *max_bytes*.

    Raises:
        ValueError: For a non-http(s) URL or a download larger than
            *max_bytes*. Other exceptions propagate if the fetch or the
            image decode fails.
    """
    import urllib.parse
    import urllib.request

    from torchvision.io import ImageReadMode, decode_image

    if urllib.parse.urlparse(url).scheme.lower() not in ("http", "https"):
        raise ValueError("Only http(s) image URLs are supported")
    with urllib.request.urlopen(url, timeout=timeout) as response:
        data = response.read(max_bytes + 1)
    if len(data) > max_bytes:
        raise ValueError("Image at URL is too large")
    # decode_image wants a 1-D uint8 tensor of the encoded bytes. bytearray
    # gives frombuffer a writable buffer.
    encoded = torch.frombuffer(bytearray(data), dtype=torch.uint8)
    decoded = decode_image(encoded, mode=ImageReadMode.RGB)  # uint8 (3, H, W)
    return transforms.functional.to_pil_image(decoded)


def _classify(image, url):
    """Click handler: classify the uploaded image, falling back to *url*.

    An uploaded image takes precedence. The URL is used only when nothing was
    uploaded.
    """
    if image is None and url and url.strip():
        try:
            image = _load_image_from_url(url.strip())
        except Exception as e:
            raise gr.Error(f"Could not load image from URL: {e}") from e
    if image is None:
        raise gr.Error("Please upload an image or paste an image URL.")
    result = predict(image)
    # predict() reports failures as {"Error": "<message>"}. gr.Label expects
    # numeric confidences, so show the message as a Gradio error instead.
    if isinstance(result.get("Error"), str):
        raise gr.Error(result["Error"])
    return result


with gr.Blocks() as demo:
    gr.Markdown("# PhytoViT - IFCB Phytoplankton Classifier")
    gr.Markdown("Upload an image or paste a URL. Model: `phytoViT_508k_20250611`")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            url_input = gr.Textbox(label="...or paste image URL")
            predict_btn = gr.Button("Classify")
        label_output = gr.Label(label="Top 5 Predictions")
    # Both inputs are wired so the advertised URL option actually works.
    predict_btn.click(
        fn=_classify,
        inputs=[image_input, url_input],
        outputs=label_output,
    )
demo.launch()