Spaces:
Sleeping
Sleeping
| # app/models/plant_vision.py | |
| import torch | |
| from PIL import Image | |
| from transformers import AutoImageProcessor, AutoModelForImageClassification | |
| from app.models.llm import explain_species | |
| from app.utils.config import DEVICE, PLANT_MODEL_ID | |
# Load the preprocessing pipeline and the classifier a single time, when this
# module is imported, so that every predict_plant() call reuses them.
processor = AutoImageProcessor.from_pretrained(PLANT_MODEL_ID)

model = AutoModelForImageClassification.from_pretrained(PLANT_MODEL_ID)
model = model.to(DEVICE)
# Switch to evaluation mode (e.g. disables dropout) since this model is only
# used for inference.
model.eval()
def predict_plant(image: Image.Image):
    """Classify a plant image and enrich the top prediction via the LLM helper.

    The image is preprocessed with the module-level ``processor`` and scored
    by the module-level ``model``. The highest-probability label is then
    passed to ``explain_species`` with ``domain="plant"``.

    Args:
        image: A single PIL image. Non-RGB images (e.g. RGBA, grayscale,
            palette) are converted to RGB before preprocessing.

    Returns:
        {
            "species": str,             # model.config.id2label of top class
            "common_name": str | None,  # taken from explain_species' result
            "confidence": float,        # softmax probability of top class
        }

    NOTE(review): no "description" key is returned. Reading it from
    explain_species' result was commented out in this function; re-enable
    it here if callers need it.
    """
    # Image processors typically expect 3-channel input; RGBA / grayscale /
    # palette images would otherwise fail or be normalized incorrectly.
    if image.mode != "RGB":
        image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt").to(DEVICE)

    # Pure inference: disable autograd so no graph is built for the forward
    # pass (saves memory and time on every call).
    with torch.no_grad():
        logits = model(**inputs).logits
        probs = torch.softmax(logits, dim=-1)
        top_prob, top_idx = probs.max(dim=-1)

    # .item() assumes a single image, i.e. a batch of size 1.
    confidence = top_prob.item()
    label = model.config.id2label[top_idx.item()]

    llm_result = explain_species(
        species=label,
        confidence=confidence,
        domain="plant",
    )

    return {
        "species": label,
        "common_name": llm_result["common_name"],
        "confidence": confidence,
    }