# app/models/plant_vision.py
"""Plant species identification from images.

Loads a Hugging Face image-classification model once at import time and
exposes predict_plant(), which classifies a PIL image and enriches the
top prediction with an LLM-generated common name and description.
"""
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

from app.models.llm import explain_species
from app.utils.config import DEVICE, PLANT_MODEL_ID

# Loaded once at module import so every request reuses the same weights.
# eval() makes dropout/batch-norm deterministic for inference.
processor = AutoImageProcessor.from_pretrained(PLANT_MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(
    PLANT_MODEL_ID
).to(DEVICE)
model.eval()


@torch.no_grad()
def predict_plant(image: Image.Image) -> dict:
    """Classify a plant photo and return the top prediction with an explanation.

    Args:
        image: Input photo as a PIL image. Any mode is accepted; non-RGB
            images (RGBA, grayscale, palette) are converted to RGB first.

    Returns:
        {
            "species": str,          # model label of the top class
            "common_name": str | None,
            "confidence": float,     # softmax probability of the top class
            "description": str,
        }
    """
    # Fix: image classifiers expect 3-channel RGB input; RGBA/grayscale/
    # palette uploads would otherwise fail or skew preprocessing.
    if image.mode != "RGB":
        image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt").to(DEVICE)
    outputs = model(**inputs)

    # Top-1 class and its probability from the softmax over logits.
    probs = torch.softmax(outputs.logits, dim=-1)
    confidence, idx = probs.max(dim=-1)
    label = model.config.id2label[idx.item()]
    conf = confidence.item()  # single device->host sync, reused below

    llm_result = explain_species(
        species=label,
        confidence=conf,
        domain="plant",
    )
    return {
        "species": label,
        "common_name": llm_result["common_name"],
        "confidence": conf,
        "description": llm_result["description"],
    }