# AgroSense / src/utils/leaf_classifier.py
# ItzRoBeerT's picture
# Added leaves classifier
# 6bc48aa
"""
Plant Disease Classifier
=========================
Classifies plant leaf diseases using MobileNetV2.
Model: linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification
- 38 classes (26 diseases + 12 healthy plants)
- 99.47% accuracy on PlantVillage dataset
- Input: 224x224 RGB image
Usage:
from src.utils.leaf_classifier import predict
result = predict(pil_image)
print(result["prediction"]) # "Tomato - Late Blight"
"""
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
# ============================================================
# CONFIGURATION
# ============================================================
# Checkpoint identifier handed to from_pretrained() for both the image
# processor and the classification model in _load_model().
MODEL_NAME = "linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification"
# ============================================================
# MODULE STATE
# ============================================================
# Cache filled in by _load_model(); every entry is None until that function
# has run. _load_model() treats "_model is not None" as "already loaded".
_model = None  # AutoModelForImageClassification instance
_processor = None  # AutoImageProcessor instance
_device = None  # "cuda" or "cpu"
# ============================================================
# PRIVATE FUNCTIONS
# ============================================================
def _load_model():
    """
    Load the classification model and its image processor, caching the result.

    The first successful call builds everything from MODEL_NAME and stores it
    in the module-level cache; later calls return the cached objects without
    touching the network or disk again.

    The cache is written only after the model has been moved to the target
    device and switched to eval mode. If any step raises, the cache stays
    empty and the next call retries from scratch instead of handing back a
    half-initialised model.

    Returns:
        tuple: (model, processor, device) where device is "cuda" or "cpu".

    Raises:
        Whatever from_pretrained(), .to() or .eval() raise; nothing is caught
        here.
    """
    global _model, _processor, _device

    # Fast path: everything was already built by an earlier call.
    if _model is not None:
        return _model, _processor, _device

    print("🌱 Loading classification model...")

    # Prefer the GPU when one is visible to torch.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f" Device: {device}")

    # Work on locals so a failure below cannot leave partial state behind.
    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
    model.to(device)
    model.eval()  # inference only: disables dropout / batch-norm updates

    # Publish the fully initialised objects in one step.
    _processor, _device, _model = processor, device, model

    print(f"✅ Model loaded: {len(model.config.id2label)} classes")
    return _model, _processor, _device
def _parse_label(raw_label: str) -> tuple:
    """
    Split a raw model label into a display-ready (plant, disease) pair.

    Labels are expected in the form "<plant>___<disease>" (triple underscore
    separator). Single underscores become spaces, parentheses are removed
    from the plant name, and the disease name is title-cased.

    Args:
        raw_label: Model label, e.g. "Tomato___Late_blight".

    Returns:
        tuple: (plant, disease), e.g. ("Tomato", "Late Blight").
            - disease is "Unknown" when the label has no "___" separator or
              nothing follows it.
            - (raw_label, "Unknown") is returned, with raw_label untouched,
              when raw_label is not a string.
    """
    # Anything that is not a string cannot be parsed; hand it back as-is
    # rather than hiding the problem behind a catch-all except.
    if not isinstance(raw_label, str):
        return (raw_label, "Unknown")

    parts = raw_label.split("___")
    plant = parts[0].replace("_", " ").replace("(", "").replace(")", "").strip()

    # Only the segment right after the first separator is the disease name.
    disease = parts[1].replace("_", " ").strip() if len(parts) > 1 else ""

    # str.title() already turns "healthy" into "Healthy", so no special case
    # is needed; an empty disease part is reported as "Unknown".
    disease = disease.title() if disease else "Unknown"
    return (plant, disease)
# ============================================================
# PUBLIC FUNCTION
# ============================================================
def predict(image: Image.Image, top_k: int = 3) -> dict:
    """
    Classify the disease shown in a plant leaf image.

    Invalid input and any exception raised while loading the model or
    running inference are reported through the returned dict rather than
    raised to the caller.

    Args:
        image: PIL image of the leaf (PIL.Image.Image). It is converted to
            RGB before being handed to the image processor.
        top_k: Number of highest-probability classes to list under "top_k";
            capped at the number of classes the model outputs.

    Returns:
        dict: On success:
            {
                "success": True,
                "prediction": "Tomato - Late Blight",
                "confidence": 95.23,          # percent, 2 decimals
                "is_healthy": False,          # "healthy" found in raw label
                "plant": "Tomato",
                "disease": "Late Blight",
                "raw_label": "Tomato___Late_blight",
                "top_k": [
                    {"plant": "Tomato", "disease": "Late Blight",
                     "confidence": 95.23,
                     "raw_label": "Tomato___Late_blight"},
                    ...
                ]
            }
        On error:
            {
                "success": False,
                "error": "Error description"
            }
    """
    # Reject unusable input before the model is touched.
    if image is None:
        return {"success": False, "error": "No image provided"}
    if not isinstance(image, Image.Image):
        return {
            "success": False,
            "error": f"Invalid image type: {type(image)}. Expected PIL.Image",
        }

    try:
        # Cached after the first call.
        classifier, preprocess, target_device = _load_model()

        # Build the model input and move every tensor to the model's device.
        rgb_image = image.convert("RGB")
        encoded = preprocess(images=rgb_image, return_tensors="pt")
        batch = {name: tensor.to(target_device) for name, tensor in encoded.items()}

        # Forward pass without gradient tracking.
        with torch.no_grad():
            logits = classifier(**batch).logits
        probabilities = torch.nn.functional.softmax(logits, dim=-1)

        # Best single class.
        id2label = classifier.config.id2label
        best_prob, best_idx = torch.max(probabilities, dim=-1)
        raw_label = id2label[best_idx.item()]
        confidence = round(best_prob.item() * 100, 2)
        plant, disease = _parse_label(raw_label)
        is_healthy = "healthy" in raw_label.lower()

        # Ranked alternatives, never asking for more classes than exist.
        how_many = min(top_k, probabilities.shape[-1])
        ranked_probs, ranked_indices = torch.topk(probabilities, how_many)
        alternatives = []
        for class_idx, class_prob in zip(ranked_indices[0], ranked_probs[0]):
            label = id2label[class_idx.item()]
            alt_plant, alt_disease = _parse_label(label)
            entry = {"plant": alt_plant, "disease": alt_disease}
            entry["confidence"] = round(class_prob.item() * 100, 2)
            entry["raw_label"] = label
            alternatives.append(entry)

        return {
            "success": True,
            "prediction": f"{plant} - {disease}",
            "confidence": confidence,
            "is_healthy": is_healthy,
            "plant": plant,
            "disease": disease,
            "raw_label": raw_label,
            "top_k": alternatives,
        }
    except Exception as exc:
        # Public boundary: report the failure as data instead of raising.
        return {"success": False, "error": str(exc)}
# Manual smoke test: runs only when this file is executed directly, so that
# importing the module does not download/load the model or print anything
# (loading stays lazy until predict() or _load_model() is first called).
if __name__ == "__main__":
    print("\n" + "=" * 50)
    print("🧪 CLASSIFIER TEST")
    print("=" * 50)

    model, processor, device = _load_model()
    print(f"\n📊 Available classes: {len(model.config.id2label)}")
    print(f"🖥️ Device: {device}")

    # Show how the first few raw labels are rendered by _parse_label().
    print("\n📋 Sample classes:")
    for idx, label in list(model.config.id2label.items())[:5]:
        plant, disease = _parse_label(label)
        print(f" {idx}: {plant} - {disease}")

    print("\n✅ Classifier ready")
    print("=" * 50)