drrobot9's picture
Update app/models/plant_vision.py
b6d92ca verified
# app/models/plant_vision.py
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
from app.models.llm import explain_species
from app.utils.config import DEVICE, PLANT_MODEL_ID
# Image preprocessor for the checkpoint named by PLANT_MODEL_ID.
processor = AutoImageProcessor.from_pretrained(PLANT_MODEL_ID)

# Classifier loaded from the same checkpoint at import time, then moved to
# DEVICE and put in eval mode so every call to predict_plant reuses it.
model = AutoModelForImageClassification.from_pretrained(PLANT_MODEL_ID)
model = model.to(DEVICE)
model.eval()
@torch.no_grad()
def predict_plant(image: Image.Image) -> dict:
    """Classify a plant image and attach LLM-provided naming details.

    The image is preprocessed with the module-level ``processor`` and run
    through ``model``. The class with the highest softmax probability is
    looked up in ``model.config.id2label``; that label and its probability
    are then passed to ``explain_species`` with ``domain="plant"``.

    Args:
        image: A single PIL image. Non-RGB images are converted to RGB
            before preprocessing.

    Returns:
        {
            "species": str,
            "common_name": str | None,
            "confidence": float,
        }
    """
    # Grayscale, palette and RGBA inputs can fail or be mis-normalized by an
    # image processor configured for 3-channel input, so normalize the mode.
    if image.mode != "RGB":
        image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt").to(DEVICE)
    logits = model(**inputs).logits

    # Top-1 prediction; .item() below relies on a batch of exactly one image.
    probs = torch.softmax(logits, dim=-1)
    top_prob, top_idx = probs.max(dim=-1)
    score = top_prob.item()
    label = model.config.id2label[top_idx.item()]

    llm_result = explain_species(
        species=label,
        confidence=score,
        domain="plant",
    )

    # NOTE(review): llm_result presumably also carries a "description" field
    # (confirm against explain_species); it is not part of this response.
    return {
        "species": label,
        "common_name": llm_result["common_name"],
        "confidence": score,
    }