File size: 1,041 Bytes
4e9208d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import os

import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Model id (or path) handed to `from_pretrained`; the CLASSIFIER_MODEL_ID
# environment variable overrides the placeholder default.
CLASSIFIER_MODEL_ID = os.environ.get(
    "CLASSIFIER_MODEL_ID", "your-username/plant-genus-classifier"
)

# Module-level cache slots; both stay None until _load() fills them.
_processor = _model = None


def _load():
    """Return the cached ``(processor, model)`` pair, loading it on first use.

    On the first call the image processor and the classification model are
    fetched via ``from_pretrained(CLASSIFIER_MODEL_ID)`` and the model is
    switched to eval mode. Later calls reuse the module-level instances.

    Returns:
        (processor, model)
    """
    global _processor, _model

    # Fast path: the model slot doubles as the "already loaded" flag.
    if _model is not None:
        return _processor, _model

    _processor = AutoImageProcessor.from_pretrained(CLASSIFIER_MODEL_ID)
    _model = AutoModelForImageClassification.from_pretrained(CLASSIFIER_MODEL_ID)
    _model.eval()
    return _processor, _model


def classify_plant(image: Image.Image) -> tuple[str, float]:
    """Predict the most likely class label for an image.

    The image is converted to RGB, preprocessed by the cached processor and
    passed through the cached classification model with gradients disabled.
    The logits are turned into probabilities with a softmax over the last
    dimension, and the highest-scoring class is reported.

    Args:
        image: PIL image to classify.

    Returns:
        (genus_name, confidence_score), where the name comes from
        ``model.config.id2label`` and the score is the softmax probability
        of that class.
    """
    processor, model = _load()

    rgb_image = image.convert("RGB")
    batch = processor(images=rgb_image, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**batch)

    scores = torch.softmax(outputs.logits, dim=-1)
    # int() only works on a one-element tensor, so this relies on the batch
    # holding exactly one image.
    best = int(scores.argmax(dim=-1))

    label = model.config.id2label[best]
    confidence = scores[0, best].item()
    return label, confidence