# HackatonSmall/modules/classifier.py
# Uploaded by Crocolil via huggingface_hub (commit 4e9208d, 1.04 kB).
import os
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Model identifier handed to ``from_pretrained()`` for both the image
# processor and the classification model. Override it with the
# CLASSIFIER_MODEL_ID environment variable.
# NOTE(review): the fallback looks like a placeholder repo id; deployments
# presumably need to set the env var to a real model — confirm.
CLASSIFIER_MODEL_ID = os.getenv(
    "CLASSIFIER_MODEL_ID",
    "your-username/plant-genus-classifier",
)

# Module-level cache filled in by _load() on first use; None means "not loaded yet".
_processor = _model = None
def _load():
    """Return the ``(processor, model)`` pair, loading and caching it on first call.

    On the first call, both objects are created with
    ``from_pretrained(CLASSIFIER_MODEL_ID)``. They are stored in the module
    globals ``_processor`` and ``_model``, and the model is switched to
    evaluation mode. Later calls return the cached pair without reloading.

    ``_model`` alone decides whether loading has already happened. If the
    model load raises, the processor may already be cached, and the next call
    loads both again.

    NOTE(review): there is no lock here, so two concurrent first calls could
    both perform the load.
    """
    global _processor, _model
    if _model is not None:
        # Already initialised: reuse the cached objects.
        return _processor, _model
    _processor = AutoImageProcessor.from_pretrained(CLASSIFIER_MODEL_ID)
    _model = AutoModelForImageClassification.from_pretrained(CLASSIFIER_MODEL_ID)
    # Put the module in inference mode (disables training-only behaviour).
    _model.eval()
    return _processor, _model
def classify_plant(image: Image.Image) -> tuple[str, float]:
    """Predict the plant genus shown in *image* with the fine-tuned classifier.

    The image is converted to RGB and preprocessed by the lazily loaded image
    processor. It then goes through the model with gradient tracking
    disabled. A softmax over the logits turns them into probabilities, and
    the most probable class is selected.

    Args:
        image: The uploaded PIL image. Only one image is expected, since the
            result is read from batch row 0.

    Returns:
        ``(genus_name, confidence_score)``. The name is the class label from
        ``model.config.id2label``; the score is its softmax probability.
    """
    processor, model = _load()
    rgb_image = image.convert("RGB")
    batch = processor(images=rgb_image, return_tensors="pt")

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**batch)

    class_probs = torch.softmax(outputs.logits, dim=-1)
    best_idx = int(class_probs.argmax(dim=-1))
    genus = model.config.id2label[best_idx]
    confidence = class_probs[0, best_idx].item()
    return genus, confidence