# FarmEyes / models / yolo_model.py
# Fola-AI
# Initial FarmEyes deployment - AI Powered Crop Disease Detection Program
# f45df09
"""
FarmEyes YOLOv11 Model Integration
==================================
Handles loading and inference with YOLOv11 model for crop disease detection.
Optimized for Apple Silicon M1 Pro with MPS (Metal Performance Shaders) acceleration.
Model: Custom trained YOLOv11 for 6 disease classes (no healthy classes)
Crops: Cassava, Cocoa, Tomato
Classes:
0: Cassava Bacteria Blight
1: Cassava Mosaic Virus
2: Cocoa Monilia Disease
3: Cocoa Phytophthora Disease
4: Tomato Gray Mold Disease
5: Tomato Wilt Disease
"""
import os
import sys
from pathlib import Path
from typing import Optional, Dict, List, Tuple, Union
from dataclasses import dataclass
import logging
# Add parent directory to path for imports
sys.path.append(str(Path(__file__).parent.parent))
import numpy as np
from PIL import Image
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# =============================================================================
# PREDICTION RESULT DATACLASS
# =============================================================================
@dataclass
class PredictionResult:
    """
    Outcome of a single disease prediction.

    Holds the predicted class, its lookup key, the confidence score and,
    when the prediction came from a detection, the bounding box.
    """

    # Index of predicted class (0-5)
    class_index: int
    # Human-readable class name
    class_name: str
    # Key for knowledge base lookup
    disease_key: str
    # Confidence score (0.0 - 1.0)
    confidence: float
    # Crop type (cassava, cocoa, tomato)
    crop_type: str
    # Whether plant is healthy (always False in 6-class model)
    is_healthy: bool
    # Bounding box [x1, y1, x2, y2] if available
    bbox: Optional[List[float]] = None

    def to_dict(self) -> Dict:
        """
        Build a JSON-friendly dictionary of this result.

        Returns:
            Dictionary with every field, the confidence rounded to 4
            decimals, plus ``confidence_percent`` rounded to 1 decimal.
        """
        score = self.confidence
        payload: Dict = {}
        payload["class_index"] = self.class_index
        payload["class_name"] = self.class_name
        payload["disease_key"] = self.disease_key
        payload["confidence"] = round(score, 4)
        payload["confidence_percent"] = round(score * 100, 1)
        payload["crop_type"] = self.crop_type
        payload["is_healthy"] = self.is_healthy
        payload["bbox"] = self.bbox
        return payload

    def __repr__(self) -> str:
        """Return a compact summary: class name, confidence (%), crop."""
        label, score, crop = self.class_name, self.confidence, self.crop_type
        return f"PredictionResult({label}, conf={score:.2%}, crop={crop})"
# =============================================================================
# YOLO MODEL CLASS
# =============================================================================
class YOLOModel:
    """
    YOLOv11 model wrapper for FarmEyes crop disease detection.

    Uses the Ultralytics library, preferring MPS acceleration on Apple
    Silicon. Construction only records settings; the weights are loaded by
    ``load_model()`` or automatically on the first prediction.

    6-class model (all diseases, no healthy classes):
        0: Cassava Bacteria Blight
        1: Cassava Mosaic Virus
        2: Cocoa Monilia Disease
        3: Cocoa Phytophthora Disease
        4: Tomato Gray Mold Disease
        5: Tomato Wilt Disease
    """

    # Class mappings (must match your trained model - 6 classes)
    CLASS_NAMES: List[str] = [
        "Cassava Bacteria Blight",     # Index 0
        "Cassava Mosaic Virus",        # Index 1
        "Cocoa Monilia Disease",       # Index 2
        "Cocoa Phytophthora Disease",  # Index 3
        "Tomato Gray Mold Disease",    # Index 4
        "Tomato Wilt Disease",         # Index 5
    ]

    # Class index to knowledge base key mapping (6 classes)
    CLASS_TO_KEY: Dict[int, str] = {
        0: "cassava_bacterial_blight",
        1: "cassava_mosaic_virus",
        2: "cocoa_monilia_disease",
        3: "cocoa_phytophthora_disease",
        4: "tomato_gray_mold",
        5: "tomato_wilt_disease",
    }

    # Class index to crop type mapping (6 classes)
    CLASS_TO_CROP: Dict[int, str] = {
        0: "cassava",  # Cassava Bacteria Blight
        1: "cassava",  # Cassava Mosaic Virus
        2: "cocoa",    # Cocoa Monilia Disease
        3: "cocoa",    # Cocoa Phytophthora Disease
        4: "tomato",   # Tomato Gray Mold Disease
        5: "tomato",   # Tomato Wilt Disease
    }

    # No healthy class indices in 6-class model (all classes are diseases)
    HEALTHY_INDICES: List[int] = []

    # Smallest / largest side length (in pixels) accepted by validate_image.
    MIN_IMAGE_SIZE: int = 32
    MAX_IMAGE_SIZE: int = 4096

    def __init__(
        self,
        model_path: Optional[str] = None,
        confidence_threshold: float = 0.5,
        iou_threshold: float = 0.45,
        device: str = "mps",
        input_size: int = 640
    ):
        """
        Initialize YOLOv11 model settings (the weights are not loaded yet).

        Args:
            model_path: Path to trained YOLOv11 .pt weights file. When empty
                or None, ``config.yolo_config.model_path`` is used.
            confidence_threshold: Minimum confidence for detections
            iou_threshold: IoU threshold for NMS
            device: Compute device ('mps' for Apple Silicon, 'cuda', 'cpu')
            input_size: Input image size for the model
        """
        if not model_path:
            # Import config here to avoid circular imports; it is only needed
            # when the caller did not supply an explicit path.
            from config import yolo_config
            model_path = str(yolo_config.model_path)

        self.model_path = model_path
        self.confidence_threshold = confidence_threshold
        self.iou_threshold = iou_threshold
        self.input_size = input_size

        # Determine best device
        self.device = self._get_best_device(device)

        # Model instance (lazy loaded)
        self._model = None
        self._is_loaded = False
        # True when the pretrained stand-in model is in use instead of the
        # trained weights; always defined so no hasattr() probing is needed.
        self._is_placeholder = False

        logger.info("YOLOModel initialized:")
        logger.info(f" Model path: {self.model_path}")
        logger.info(f" Device: {self.device}")
        logger.info(f" Confidence threshold: {self.confidence_threshold}")
        logger.info(f" Input size: {self.input_size}")
        logger.info(f" Number of classes: {len(self.CLASS_NAMES)}")

    # =========================================================================
    # DEVICE MANAGEMENT
    # =========================================================================
    def _get_best_device(self, preferred: str = "mps") -> str:
        """
        Resolve the requested compute device, falling back to CPU.

        Args:
            preferred: Preferred device ('mps', 'cuda', 'cpu')

        Returns:
            ``preferred`` when it is 'mps' or 'cuda' and that backend reports
            itself available, otherwise 'cpu'.
        """
        import torch

        if preferred == "mps":
            # Guard the attribute lookup: not every torch build exposes
            # torch.backends.mps.
            mps_backend = getattr(torch.backends, "mps", None)
            if mps_backend is not None and mps_backend.is_available():
                logger.info("Using MPS (Metal Performance Shaders) for Apple Silicon")
                return "mps"
        elif preferred == "cuda":
            if torch.cuda.is_available():
                logger.info(f"Using CUDA: {torch.cuda.get_device_name(0)}")
                return "cuda"

        if preferred != "cpu":
            # Make the silent downgrade visible to whoever configured the device.
            logger.warning(
                f"Requested device '{preferred}' is not available; falling back to CPU"
            )
        logger.info("Using CPU for inference")
        return "cpu"

    # =========================================================================
    # MODEL LOADING
    # =========================================================================
    def load_model(self) -> bool:
        """
        Load the YOLOv11 model into memory.

        If the weights file does not exist, the pretrained ``yolo11n.pt``
        model is loaded as a placeholder and the instance is flagged as such.

        Returns:
            True if model loaded successfully (or was already loaded)

        Raises:
            ImportError: If the ultralytics package cannot be imported.
            RuntimeError: If loading or moving the model to the device fails.
        """
        if self._is_loaded:
            logger.info("Model already loaded")
            return True

        try:
            from ultralytics import YOLO

            # Check if model file exists
            if not Path(self.model_path).exists():
                logger.warning(f"Model file not found at {self.model_path}")
                logger.warning("Using placeholder - please provide trained model")
                # Create a placeholder with pretrained YOLOv11n for testing
                # Replace this with your actual trained model
                logger.info("Loading pretrained YOLOv11n as placeholder...")
                model = YOLO("yolo11n.pt")  # Downloads pretrained model
                is_placeholder = True
            else:
                logger.info(f"Loading YOLOv11 model from {self.model_path}...")
                model = YOLO(self.model_path)
                is_placeholder = False

            # Move model to device
            model.to(self.device)
        except ImportError as exc:
            logger.error("Ultralytics not installed!")
            logger.error("Install with: pip install ultralytics")
            raise ImportError("ultralytics package is required") from exc
        except Exception as exc:
            logger.error(f"Failed to load model: {exc}")
            # Leave no half-initialised model behind after a failed load.
            self._model = None
            self._is_loaded = False
            raise RuntimeError(f"Could not load YOLOv11 model: {exc}") from exc

        # Commit state only once every step above has succeeded.
        self._model = model
        self._is_placeholder = is_placeholder
        self._is_loaded = True
        logger.info(f"✅ YOLOv11 model loaded successfully on {self.device}!")
        return True

    def unload_model(self):
        """Unload model from memory and clear the accelerator cache."""
        if self._model is None:
            return

        # Dropping the only reference held here lets the model be collected.
        self._model = None
        self._is_loaded = False
        self._is_placeholder = False

        # Clear GPU cache
        import torch
        if self.device == "mps":
            # Guarded: not every torch build exposes torch.mps.empty_cache.
            mps_module = getattr(torch, "mps", None)
            if mps_module is not None and hasattr(mps_module, "empty_cache"):
                mps_module.empty_cache()
        elif self.device == "cuda":
            torch.cuda.empty_cache()

        logger.info("Model unloaded from memory")

    @property
    def is_loaded(self) -> bool:
        """Check if model is loaded."""
        return self._is_loaded

    # =========================================================================
    # IMAGE PREPROCESSING
    # =========================================================================
    def preprocess_image(
        self,
        image: Union[str, Path, Image.Image, np.ndarray]
    ) -> Image.Image:
        """
        Preprocess image for inference.

        Args:
            image: Input image (path, PIL Image, or numpy array)

        Returns:
            PIL Image in RGB mode ready for inference

        Raises:
            FileNotFoundError: If ``image`` is a path that does not exist.
            TypeError: If ``image`` is not a path, PIL Image or numpy array.
        """
        # Handle different input types
        if isinstance(image, (str, Path)):
            image_path = Path(image)
            if not image_path.exists():
                raise FileNotFoundError(f"Image not found: {image_path}")
            # convert() yields a new image detached from the file, so the
            # handle can be closed as soon as the with-block exits.
            with Image.open(image_path) as opened:
                return opened.convert("RGB")

        if isinstance(image, np.ndarray):
            pil_image = Image.fromarray(image)
        elif isinstance(image, Image.Image):
            pil_image = image
        else:
            raise TypeError(f"Unsupported image type: {type(image)}")

        # Convert to RGB if necessary
        if pil_image.mode != "RGB":
            pil_image = pil_image.convert("RGB")
        return pil_image

    def validate_image(self, image: Image.Image) -> Tuple[bool, str]:
        """
        Validate image dimensions for inference.

        Args:
            image: PIL Image to validate

        Returns:
            Tuple of (is_valid, message)
        """
        # Check image size
        width, height = image.size
        smallest, largest = self.MIN_IMAGE_SIZE, self.MAX_IMAGE_SIZE
        if width < smallest or height < smallest:
            return False, f"Image too small. Minimum size is {smallest}x{smallest} pixels."
        if width > largest or height > largest:
            return False, f"Image too large. Maximum size is {largest}x{largest} pixels."
        return True, "Image is valid"

    # =========================================================================
    # INFERENCE
    # =========================================================================
    def _run_inference(self, pil_image: Image.Image):
        """Call the loaded model on one image using this instance's settings."""
        return self._model(
            pil_image,
            conf=self.confidence_threshold,
            iou=self.iou_threshold,
            imgsz=self.input_size,
            device=self.device,
            verbose=False
        )

    def predict(
        self,
        image: Union[str, Path, Image.Image, np.ndarray]
    ) -> PredictionResult:
        """
        Run disease detection on an image.

        The model is loaded on first use. Errors from loading the model or
        from ``preprocess_image`` propagate to the caller; a failed
        validation, an empty result or an inference error yields the
        "Unknown" result from ``_create_low_confidence_result``.

        Args:
            image: Input image (path, PIL Image, or numpy array)

        Returns:
            PredictionResult with disease information (highest confidence)
        """
        if not self._is_loaded:
            self.load_model()

        # Preprocess image
        pil_image = self.preprocess_image(image)

        # Validate image
        is_valid, message = self.validate_image(pil_image)
        if not is_valid:
            logger.warning(f"Image validation failed: {message}")
            return self._create_low_confidence_result()

        try:
            predictions = self._parse_results(self._run_inference(pil_image))
        except Exception as e:
            # Best-effort API: report the failure and answer "Unknown".
            logger.exception(f"Inference failed: {e}")
            return self._create_low_confidence_result()

        if not predictions:
            logger.info("No predictions above confidence threshold")
            return self._create_low_confidence_result()

        # Return top prediction
        return predictions[0]

    def predict_with_visualization(
        self,
        image: Union[str, Path, Image.Image, np.ndarray]
    ) -> Tuple[PredictionResult, Image.Image]:
        """
        Run detection and return annotated image.

        Args:
            image: Input image

        Returns:
            Tuple of (PredictionResult, annotated PIL Image). When validation
            or inference fails, the preprocessed input image is returned
            in place of the annotated one.
        """
        if not self._is_loaded:
            self.load_model()

        # Preprocess image
        pil_image = self.preprocess_image(image)

        # Validate image
        is_valid, message = self.validate_image(pil_image)
        if not is_valid:
            logger.warning(f"Image validation failed: {message}")
            return self._create_low_confidence_result(), pil_image

        try:
            results = self._run_inference(pil_image)
            predictions = self._parse_results(results)

            # Get annotated image
            annotated = results[0].plot()
            annotated_pil = Image.fromarray(annotated[..., ::-1])  # BGR to RGB
        except Exception as e:
            logger.exception(f"Inference with visualization failed: {e}")
            return self._create_low_confidence_result(), pil_image

        if not predictions:
            return self._create_low_confidence_result(), annotated_pil
        return predictions[0], annotated_pil

    def _build_prediction(
        self,
        class_index: int,
        confidence: float,
        bbox: Optional[List[float]] = None
    ) -> Optional[PredictionResult]:
        """
        Build a PredictionResult for one raw (class, confidence) pair.

        Returns None when the class index is outside the known classes.
        """
        if self._is_placeholder:
            # Map the placeholder model's classes onto ours for demo purposes
            class_index %= len(self.CLASS_NAMES)

        if not 0 <= class_index < len(self.CLASS_NAMES):
            return None

        return PredictionResult(
            class_index=class_index,
            class_name=self.CLASS_NAMES[class_index],
            disease_key=self.CLASS_TO_KEY[class_index],
            confidence=confidence,
            crop_type=self.CLASS_TO_CROP[class_index],
            is_healthy=class_index in self.HEALTHY_INDICES,  # Always False for 6-class model
            bbox=bbox
        )

    def _parse_results(self, results) -> List[PredictionResult]:
        """
        Parse YOLO results into PredictionResult objects.

        Each result is read as a classification result when it carries
        ``probs`` (top-1 only), otherwise as a detection result via ``boxes``.

        Args:
            results: YOLO inference results

        Returns:
            List of PredictionResult objects sorted by confidence
        """
        predictions: List[PredictionResult] = []

        for result in results:
            # Check if we have classification results (for classification model)
            probs = getattr(result, "probs", None)
            if probs is not None:
                candidate = self._build_prediction(
                    int(probs.top1), float(probs.top1conf)
                )
                if candidate is not None:
                    predictions.append(candidate)
                continue

            # Check for detection results (for detection model)
            boxes = getattr(result, "boxes", None)
            if boxes is None:
                continue

            xyxy = boxes.xyxy  # read once instead of once per box
            for i, (cls_value, conf_value) in enumerate(zip(boxes.cls, boxes.conf)):
                bbox = xyxy[i].tolist() if xyxy is not None else None
                candidate = self._build_prediction(
                    int(cls_value), float(conf_value), bbox
                )
                if candidate is not None:
                    predictions.append(candidate)

        # Sort by confidence (highest first)
        predictions.sort(key=lambda x: x.confidence, reverse=True)
        return predictions

    def _create_low_confidence_result(self) -> PredictionResult:
        """Create a result indicating low confidence / no detection."""
        return PredictionResult(
            class_index=-1,
            class_name="Unknown",
            disease_key="unknown",
            confidence=0.0,
            crop_type="unknown",
            is_healthy=False
        )

    # =========================================================================
    # BATCH INFERENCE
    # =========================================================================
    def predict_batch(
        self,
        images: List[Union[str, Path, Image.Image, np.ndarray]]
    ) -> List[PredictionResult]:
        """
        Run detection on multiple images, one at a time.

        An image that raises during processing contributes the "Unknown"
        result instead of aborting the batch.

        Args:
            images: List of input images

        Returns:
            List of PredictionResult objects (one per image)
        """
        if not self._is_loaded:
            self.load_model()

        results = []
        for image in images:
            try:
                results.append(self.predict(image))
            except Exception as e:
                logger.error(f"Failed to process image: {e}")
                results.append(self._create_low_confidence_result())
        return results

    # =========================================================================
    # UTILITY METHODS
    # =========================================================================
    def get_class_info(self, class_index: int) -> Dict:
        """
        Get information about a class by index.

        Args:
            class_index: Index of the class (0-5)

        Returns:
            Dictionary with class information; "Unknown" entries when the
            index is out of range.
        """
        if class_index < 0 or class_index >= len(self.CLASS_NAMES):
            return {
                "class_index": class_index,
                "class_name": "Unknown",
                "disease_key": "unknown",
                "crop_type": "unknown",
                "is_healthy": False
            }

        return {
            "class_index": class_index,
            "class_name": self.CLASS_NAMES[class_index],
            "disease_key": self.CLASS_TO_KEY[class_index],
            "crop_type": self.CLASS_TO_CROP[class_index],
            "is_healthy": class_index in self.HEALTHY_INDICES  # Always False for 6-class model
        }

    def get_model_info(self) -> Dict:
        """
        Get information about the model and its settings.

        Returns:
            Dictionary of settings; ``is_placeholder`` is included only
            while a model is loaded.
        """
        info = {
            "model_path": self.model_path,
            "is_loaded": self._is_loaded,
            "device": self.device,
            "confidence_threshold": self.confidence_threshold,
            "input_size": self.input_size,
            "num_classes": len(self.CLASS_NAMES),
            # Copy so callers cannot mutate the shared class-level list.
            "classes": list(self.CLASS_NAMES)
        }
        if self._is_loaded:
            info["is_placeholder"] = self._is_placeholder
        return info
# =============================================================================
# SINGLETON INSTANCE
# =============================================================================
_model_instance: Optional[YOLOModel] = None


def get_yolo_model() -> YOLOModel:
    """
    Return the shared YOLOModel, creating it on first use.

    The instance is built from the settings in ``config.yolo_config`` and
    cached in the module-level ``_model_instance``.

    Returns:
        YOLOModel instance
    """
    global _model_instance

    # Fast path: the singleton already exists.
    if _model_instance is not None:
        return _model_instance

    from config import yolo_config

    settings = {
        "model_path": str(yolo_config.model_path),
        "confidence_threshold": yolo_config.confidence_threshold,
        "iou_threshold": yolo_config.iou_threshold,
        "device": yolo_config.device,
        "input_size": yolo_config.input_size,
    }
    _model_instance = YOLOModel(**settings)
    return _model_instance
def unload_yolo_model():
    """Release the shared YOLO model, if one exists, to free memory."""
    global _model_instance

    current = _model_instance
    if current is None:
        return

    # Unload first; the singleton is only forgotten once that succeeded.
    current.unload_model()
    _model_instance = None
# =============================================================================
# CONVENIENCE FUNCTIONS
# =============================================================================
def detect_disease(
    image: Union[str, Path, Image.Image, np.ndarray]
) -> PredictionResult:
    """
    Detect disease in an image using the shared YOLO model.

    Args:
        image: Input image (path, PIL Image, or numpy array)

    Returns:
        PredictionResult with disease information
    """
    return get_yolo_model().predict(image)
def detect_disease_with_image(
    image: Union[str, Path, Image.Image, np.ndarray]
) -> Tuple[PredictionResult, Image.Image]:
    """
    Detect disease using the shared YOLO model and return an annotated image.

    Args:
        image: Input image

    Returns:
        Tuple of (PredictionResult, annotated Image)
    """
    return get_yolo_model().predict_with_visualization(image)
# =============================================================================
# MAIN - Test the model
# =============================================================================
if __name__ == "__main__":
    # Manual smoke test: report the device, load the model, print its settings.
    import torch

    banner = "=" * 60

    print(banner)
    print("YOLOv11 Model Test (6-Class Disease Detection)")
    print(banner)

    # Step 1: compute device
    print("\n1. Checking compute device...")
    print(f" PyTorch version: {torch.__version__}")
    print(f" MPS available: {torch.backends.mps.is_available()}")
    print(f" MPS built: {torch.backends.mps.is_built()}")

    # Step 2: construct the wrapper (weights are not loaded yet)
    print("\n2. Initializing YOLOv11 model...")
    model = YOLOModel()

    # Step 3: load the weights
    print("\n3. Loading model...")
    model.load_model()

    # Step 4: dump the model settings
    print("\n4. Model information:")
    info = model.get_model_info()
    for field, setting in info.items():
        print(f" {field}: {setting}")

    # Step 5: usage hint (no sample image is bundled with the script)
    print("\n5. Testing inference...")
    print(" To test with an actual image, run:")
    print(" >>> result = model.predict('path/to/your/image.jpg')")
    print(" >>> print(result)")

    # Step 6: class mappings
    print("\n6. Class mappings (6 classes - all diseases):")
    for idx, name in enumerate(model.CLASS_NAMES):
        print(f" {idx}: {name}")
        print(f" Crop: {model.CLASS_TO_CROP[idx]}")
        print(f" Key: {model.CLASS_TO_KEY[idx]}")

    print("\n" + banner)
    print("✅ YOLOv11 model test completed!")
    print(banner)