"""Lazy-loading wrapper around a Hugging Face image-classification model."""

import os

import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Model identifier handed to ``from_pretrained``; the environment variable
# of the same name overrides the placeholder default.
CLASSIFIER_MODEL_ID = os.getenv(
    "CLASSIFIER_MODEL_ID", "your-username/plant-genus-classifier"
)

# Module-level cache, populated by the first call to ``_load``.
_processor = None
_model = None


def _load():
    """Return the cached ``(processor, model)`` pair, loading it if needed.

    The first call instantiates both objects from ``CLASSIFIER_MODEL_ID``
    and puts the model into eval mode; subsequent calls reuse the
    module-level instances.
    """
    global _processor, _model

    # Already initialised: hand back the cached objects untouched.
    if _model is not None:
        return _processor, _model

    _processor = AutoImageProcessor.from_pretrained(CLASSIFIER_MODEL_ID)
    _model = AutoModelForImageClassification.from_pretrained(CLASSIFIER_MODEL_ID)
    _model.eval()
    return _processor, _model


def classify_plant(image: Image.Image) -> tuple[str, float]:
    """Classify a single image with the configured classifier model.

    The image is converted to RGB, preprocessed into PyTorch tensors and
    passed through the model with gradient tracking disabled.

    Args:
        image: The PIL image to classify.

    Returns:
        A ``(label, confidence)`` pair: the ``id2label`` entry of the
        highest-probability class and that class's softmax probability.
    """
    processor, model = _load()

    rgb_image = image.convert("RGB")
    batch = processor(images=rgb_image, return_tensors="pt")

    # Inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**batch)
    probabilities = torch.softmax(outputs.logits, dim=-1)

    # ``int()`` only accepts a one-element tensor, i.e. a batch of one image.
    best_index = int(probabilities.argmax(dim=-1))
    label = model.config.id2label[best_index]
    confidence = probabilities[0, best_index].item()
    return label, confidence