# Extracted from a Hugging Face Space file view (revision 4e9208d, 1,041 bytes).
import os
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Identifier handed to `from_pretrained` for both the image processor and the
# model. Overridable through the CLASSIFIER_MODEL_ID environment variable.
# NOTE(review): the fallback looks like a placeholder repo id — confirm it is
# replaced (or the env var is set) in any real deployment.
CLASSIFIER_MODEL_ID = os.environ.get(
    "CLASSIFIER_MODEL_ID",
    "your-username/plant-genus-classifier",
)

# Module-level cache filled in by `_load()` on first use; both stay None until
# then.
_processor = _model = None
def _load():
    """Return the cached ``(processor, model)`` pair, loading it on first use.

    The first call instantiates the image processor and the classification
    model from ``CLASSIFIER_MODEL_ID``, switches the model to eval mode, and
    stores both in module globals. Later calls return the stored objects
    without loading again.

    Returns:
        Tuple ``(processor, model)``.
    """
    global _processor, _model

    # Already loaded: hand back the cached objects.
    if _model is not None:
        return _processor, _model

    _processor = AutoImageProcessor.from_pretrained(CLASSIFIER_MODEL_ID)
    _model = AutoModelForImageClassification.from_pretrained(CLASSIFIER_MODEL_ID)
    _model.eval()
    return _processor, _model
def classify_plant(image: Image.Image) -> tuple[str, float]:
    """Classify an uploaded plant image with the fine-tuned genus classifier.

    Args:
        image: Picture to classify; it is converted to RGB before being
            passed to the image processor.

    Returns:
        (genus_name, confidence_score): the ``id2label`` entry for the
        highest-probability class and that class's softmax probability.
    """
    processor, model = _load()

    rgb_image = image.convert("RGB")
    batch = processor(images=rgb_image, return_tensors="pt")

    # Inference only, so skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        outputs = model(**batch)

    probabilities = outputs.logits.softmax(dim=-1)
    best_index = int(probabilities.argmax(dim=-1))

    genus_name = model.config.id2label[best_index]
    confidence_score = probabilities[0, best_index].item()
    return genus_name, confidence_score