Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| from PIL import Image | |
| from transformers import AutoImageProcessor, AutoModelForImageClassification | |
# Model repository to load; the CLASSIFIER_MODEL_ID environment variable
# overrides the default identifier below.
CLASSIFIER_MODEL_ID = os.environ.get(
    "CLASSIFIER_MODEL_ID",
    "your-username/plant-genus-classifier",
)

# Module-level cache filled on first use by _load(); None means "not loaded yet".
_processor = _model = None
def _load():
    """Return the cached ``(processor, model)`` pair, loading it on first use.

    On the first call the image processor and the classification model are
    both fetched via ``from_pretrained(CLASSIFIER_MODEL_ID)``, the model is
    switched to eval mode, and both objects are stored in the module-level
    ``_processor`` / ``_model`` globals. Later calls return the stored objects
    without loading again.

    Returns:
        (processor, model)
    """
    global _processor, _model

    # Fast path: a previous call already populated the cache.
    if _model is not None:
        return _processor, _model

    _processor = AutoImageProcessor.from_pretrained(CLASSIFIER_MODEL_ID)
    _model = AutoModelForImageClassification.from_pretrained(CLASSIFIER_MODEL_ID)
    # Inference only: put the model in evaluation mode once, right after loading.
    _model.eval()
    return _processor, _model
def classify_plant(image: Image.Image) -> tuple[str, float]:
    """Classify an uploaded image with the genus classifier.

    The image is converted to RGB, preprocessed by the cached processor, and
    passed through the cached model with gradient tracking disabled. The
    logits are turned into probabilities with a softmax over the last
    dimension, and the highest-probability class is reported.

    Args:
        image: The PIL image to classify.

    Returns:
        (genus_name, confidence_score) -- the ``model.config.id2label`` entry
        for the top class and that class's softmax probability as a float.
    """
    image_processor, classifier = _load()

    rgb_image = image.convert("RGB")
    batch = image_processor(images=rgb_image, return_tensors="pt")

    # Pure inference: no gradients are needed for the forward pass.
    with torch.no_grad():
        outputs = classifier(**batch)

    probabilities = torch.softmax(outputs.logits, dim=-1)
    # int() needs a single-element tensor, which holds because exactly one
    # image is processed per call.
    best_index = int(probabilities.argmax(dim=-1))

    genus_name = classifier.config.id2label[best_index]
    # Row 0 is the only row, since a single image was passed in.
    confidence_score = probabilities[0, best_index].item()
    return genus_name, confidence_score