|
|
""" |
|
|
FarmEyes YOLOv11 Model Integration |
|
|
================================== |
|
|
Handles loading and inference with YOLOv11 model for crop disease detection. |
|
|
Optimized for Apple Silicon M1 Pro with MPS (Metal Performance Shaders) acceleration. |
|
|
|
|
|
Model: Custom trained YOLOv11 for 6 disease classes (no healthy classes) |
|
|
Crops: Cassava, Cocoa, Tomato |
|
|
Classes: |
|
|
0: Cassava Bacteria Blight |
|
|
1: Cassava Mosaic Virus |
|
|
2: Cocoa Monilia Disease |
|
|
3: Cocoa Phytophthora Disease |
|
|
4: Tomato Gray Mold Disease |
|
|
5: Tomato Wilt Disease |
|
|
""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
from pathlib import Path |
|
|
from typing import Optional, Dict, List, Tuple, Union |
|
|
from dataclasses import dataclass |
|
|
import logging |
|
|
|
|
|
|
|
|
# Add this file's grandparent directory to sys.path so that the project-level
# `config` module can be resolved. `config` is imported lazily inside
# YOLOModel.__init__ and get_yolo_model. NOTE(review): presumably config.py
# lives in that directory; verify against the project layout.
sys.path.append(str(Path(__file__).parent.parent))
|
|
|
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
# Configure the root logger at INFO level as soon as this module is imported.
# NOTE(review): configuring logging from a library module affects every
# importer; consider moving this to the application entry point.
logging.basicConfig(level=logging.INFO)

# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class PredictionResult:
    """
    Outcome of a single crop-disease prediction.

    YOLOModel uses ``class_index == -1`` with name "Unknown" as the sentinel
    for "nothing detected / low confidence".
    """

    class_index: int  # model class id (-1 for the unknown sentinel)
    class_name: str  # human-readable disease name
    disease_key: str  # snake_case disease identifier
    confidence: float  # score; to_dict() also reports it scaled by 100 as a percentage
    crop_type: str  # crop the disease belongs to, or "unknown"
    is_healthy: bool  # True when the class denotes a healthy plant
    bbox: Optional[List[float]] = None  # [x1, y1, x2, y2] from detection output, else None

    def to_dict(self) -> Dict:
        """Return a JSON-serialisable mapping of this result (confidence rounded)."""
        payload = {
            name: getattr(self, name)
            for name in ("class_index", "class_name", "disease_key")
        }
        payload["confidence"] = round(self.confidence, 4)
        payload["confidence_percent"] = round(self.confidence * 100, 1)
        payload.update(
            crop_type=self.crop_type,
            is_healthy=self.is_healthy,
            bbox=self.bbox,
        )
        return payload

    def __repr__(self) -> str:
        return "PredictionResult({}, conf={:.2%}, crop={})".format(
            self.class_name, self.confidence, self.crop_type
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class YOLOModel:
    """
    YOLOv11 model wrapper for FarmEyes crop disease detection.

    Wraps the Ultralytics ``YOLO`` API and runs inference on the best
    available torch device (MPS on Apple Silicon, CUDA, or CPU).

    6-class model (all diseases, no healthy classes):
        0: Cassava Bacteria Blight
        1: Cassava Mosaic Virus
        2: Cocoa Monilia Disease
        3: Cocoa Phytophthora Disease
        4: Tomato Gray Mold Disease
        5: Tomato Wilt Disease

    Weights are loaded lazily: the prediction methods call ``load_model``
    themselves if it has not been called yet.
    """

    # Display names, indexed by the model's class id.
    CLASS_NAMES: List[str] = [
        "Cassava Bacteria Blight",
        "Cassava Mosaic Virus",
        "Cocoa Monilia Disease",
        "Cocoa Phytophthora Disease",
        "Tomato Gray Mold Disease",
        "Tomato Wilt Disease",
    ]

    # Class id -> snake_case identifier reported as PredictionResult.disease_key.
    CLASS_TO_KEY: Dict[int, str] = {
        0: "cassava_bacterial_blight",
        1: "cassava_mosaic_virus",
        2: "cocoa_monilia_disease",
        3: "cocoa_phytophthora_disease",
        4: "tomato_gray_mold",
        5: "tomato_wilt_disease",
    }

    # Class id -> crop the disease affects.
    CLASS_TO_CROP: Dict[int, str] = {
        0: "cassava",
        1: "cassava",
        2: "cocoa",
        3: "cocoa",
        4: "tomato",
        5: "tomato",
    }

    # Class ids that denote a healthy plant. Empty because this model only has
    # disease classes; kept so ``is_healthy`` stays data-driven.
    HEALTHY_INDICES: List[int] = []

    # Inclusive pixel bounds that validate_image enforces on width and height.
    MIN_IMAGE_SIZE: int = 32
    MAX_IMAGE_SIZE: int = 4096

    def __init__(
        self,
        model_path: Optional[str] = None,
        confidence_threshold: float = 0.5,
        iou_threshold: float = 0.45,
        device: str = "mps",
        input_size: int = 640
    ):
        """
        Initialize the YOLOv11 wrapper (weights are not loaded yet).

        Args:
            model_path: Path to trained YOLOv11 .pt weights file; when omitted,
                ``config.yolo_config.model_path`` is used
            confidence_threshold: Minimum confidence for a prediction to be reported
            iou_threshold: IoU threshold for NMS
            device: Preferred compute device ('mps', 'cuda', 'cpu'); another
                available device is chosen if it is unavailable
            input_size: Input image size for the model
        """
        if not model_path:
            # Only require the project config when no explicit path was given.
            from config import yolo_config

            model_path = str(yolo_config.model_path)

        self.model_path = model_path
        self.confidence_threshold = confidence_threshold
        self.iou_threshold = iou_threshold
        self.input_size = input_size

        self.device = self._get_best_device(device)

        self._model = None
        self._is_loaded = False
        # True while the generic fallback checkpoint (not the trained model) is loaded.
        self._is_placeholder = False

        logger.info("YOLOModel initialized:")
        logger.info(f" Model path: {self.model_path}")
        logger.info(f" Device: {self.device}")
        logger.info(f" Confidence threshold: {self.confidence_threshold}")
        logger.info(f" Input size: {self.input_size}")
        logger.info(f" Number of classes: {len(self.CLASS_NAMES)}")

    def _get_best_device(self, preferred: str = "mps") -> str:
        """
        Resolve the compute device to run inference on.

        The preferred device is used when it is available. If an accelerator
        was requested but is unavailable, another available accelerator (CUDA,
        then MPS) is used before falling back to CPU. Requesting 'cpu' always
        yields 'cpu'.

        Args:
            preferred: Preferred device ('mps', 'cuda', 'cpu'; case-insensitive)

        Returns:
            One of 'mps', 'cuda' or 'cpu'
        """
        import torch

        preferred = (preferred or "cpu").lower()

        # torch.backends.mps does not exist on old torch builds.
        mps_backend = getattr(torch.backends, "mps", None)
        available = {
            "cuda": torch.cuda.is_available(),
            "mps": mps_backend is not None and mps_backend.is_available(),
        }

        def _announce(chosen: str) -> str:
            if chosen == "mps":
                logger.info("Using MPS (Metal Performance Shaders) for Apple Silicon")
            elif chosen == "cuda":
                logger.info(f"Using CUDA: {torch.cuda.get_device_name(0)}")
            else:
                logger.info("Using CPU for inference")
            return chosen

        if preferred == "cpu":
            return _announce("cpu")
        if available.get(preferred):
            return _announce(preferred)

        # The requested accelerator is missing. Without this fallback, the
        # default ("mps") would silently pick CPU on CUDA-only machines.
        for fallback in ("cuda", "mps"):
            if available[fallback]:
                logger.warning(
                    f"Requested device '{preferred}' is unavailable; falling back to {fallback}"
                )
                return _announce(fallback)
        return _announce("cpu")

    def load_model(self) -> bool:
        """
        Load the YOLOv11 weights onto ``self.device``.

        If ``model_path`` does not exist, the generic ``yolo11n.pt`` checkpoint
        is loaded as a placeholder instead.

        Returns:
            True if the model is loaded (or was already loaded)

        Raises:
            ImportError: if the ultralytics package is not installed
            RuntimeError: if the weights cannot be loaded or moved to the device
        """
        if self._is_loaded:
            logger.info("Model already loaded")
            return True

        # Import separately so that only a genuinely missing package is
        # reported as "not installed".
        try:
            from ultralytics import YOLO
        except ImportError as exc:
            logger.error("Ultralytics not installed!")
            logger.error("Install with: pip install ultralytics")
            raise ImportError("ultralytics package is required") from exc

        try:
            if Path(self.model_path).exists():
                logger.info(f"Loading YOLOv11 model from {self.model_path}...")
                model = YOLO(self.model_path)
                is_placeholder = False
            else:
                logger.warning(f"Model file not found at {self.model_path}")
                logger.warning("Using placeholder - please provide trained model")
                # NOTE(review): the placeholder checkpoint's classes are unrelated
                # to the disease classes; _parse_results folds its class ids into
                # range, so its "diagnoses" are not meaningful.
                logger.info("Loading pretrained YOLOv11n as placeholder...")
                model = YOLO("yolo11n.pt")
                is_placeholder = True

            model.to(self.device)
        except Exception as exc:
            logger.error(f"Failed to load model: {exc}")
            self._model = None
            self._is_loaded = False
            raise RuntimeError(f"Could not load YOLOv11 model: {exc}") from exc

        # Publish state only after the model is fully loaded and on the device.
        self._model = model
        self._is_placeholder = is_placeholder
        self._is_loaded = True
        logger.info(f"✅ YOLOv11 model loaded successfully on {self.device}!")
        return True

    def unload_model(self):
        """Drop the loaded model and empty the accelerator cache for ``self.device``."""
        if self._model is None:
            return

        self._model = None
        self._is_loaded = False

        import torch

        if self.device == "mps":
            # torch.mps.empty_cache is missing on older torch builds.
            mps_module = getattr(torch, "mps", None)
            if mps_module is not None and hasattr(mps_module, "empty_cache"):
                mps_module.empty_cache()
        elif self.device == "cuda":
            torch.cuda.empty_cache()

        logger.info("Model unloaded from memory")

    @property
    def is_loaded(self) -> bool:
        """Check if model is loaded."""
        return self._is_loaded

    def preprocess_image(
        self,
        image: Union[str, Path, Image.Image, np.ndarray]
    ) -> Image.Image:
        """
        Convert the supported input types into an RGB PIL image.

        Args:
            image: Input image (path, PIL Image, or numpy array)

        Returns:
            PIL Image in RGB mode

        Raises:
            FileNotFoundError: if a path is given and does not exist
            TypeError: for unsupported input types
        """
        if isinstance(image, (str, Path)):
            image_path = Path(image)
            if not image_path.exists():
                raise FileNotFoundError(f"Image not found: {image_path}")
            # Image.open is lazy and keeps the file open. convert() decodes into
            # a new in-memory image, so the handle can be closed right away.
            with Image.open(image_path) as opened:
                return opened.convert("RGB")

        if isinstance(image, np.ndarray):
            pil_image = Image.fromarray(image)
        elif isinstance(image, Image.Image):
            pil_image = image
        else:
            raise TypeError(f"Unsupported image type: {type(image)}")

        if pil_image.mode != "RGB":
            pil_image = pil_image.convert("RGB")
        return pil_image

    def validate_image(self, image: Image.Image) -> Tuple[bool, str]:
        """
        Check that the image dimensions are within the accepted bounds.

        Args:
            image: PIL Image to validate

        Returns:
            Tuple of (is_valid, message)
        """
        width, height = image.size
        low, high = self.MIN_IMAGE_SIZE, self.MAX_IMAGE_SIZE

        if width < low or height < low:
            return False, f"Image too small. Minimum size is {low}x{low} pixels."
        if width > high or height > high:
            return False, f"Image too large. Maximum size is {high}x{high} pixels."
        return True, "Image is valid"

    def _run_inference(self, pil_image: Image.Image):
        """Run the underlying Ultralytics model with this wrapper's settings."""
        return self._model(
            pil_image,
            conf=self.confidence_threshold,
            iou=self.iou_threshold,
            imgsz=self.input_size,
            device=self.device,
            verbose=False
        )

    def predict(
        self,
        image: Union[str, Path, Image.Image, np.ndarray]
    ) -> PredictionResult:
        """
        Run disease detection on an image and return the most confident prediction.

        Args:
            image: Input image (path, PIL Image, or numpy array)

        Returns:
            The highest-confidence PredictionResult. The "Unknown" result is
            returned when the image fails validation, inference fails, or
            nothing passes the confidence threshold.

        Raises:
            FileNotFoundError, TypeError: from ``preprocess_image``
            ImportError, RuntimeError: if the model cannot be loaded
        """
        if not self._is_loaded:
            self.load_model()

        pil_image = self.preprocess_image(image)

        is_valid, message = self.validate_image(pil_image)
        if not is_valid:
            logger.warning(f"Image validation failed: {message}")
            return self._create_low_confidence_result()

        try:
            predictions = self._parse_results(self._run_inference(pil_image))
        except Exception as e:
            logger.error(f"Inference failed: {e}", exc_info=True)
            return self._create_low_confidence_result()

        if not predictions:
            logger.info("No predictions above confidence threshold")
            return self._create_low_confidence_result()

        return predictions[0]

    def predict_with_visualization(
        self,
        image: Union[str, Path, Image.Image, np.ndarray]
    ) -> Tuple[PredictionResult, Image.Image]:
        """
        Run detection and also return an annotated image.

        Args:
            image: Input image

        Returns:
            Tuple of (PredictionResult, annotated PIL Image). On validation or
            inference failure, the un-annotated RGB input image is returned.
        """
        if not self._is_loaded:
            self.load_model()

        pil_image = self.preprocess_image(image)

        is_valid, message = self.validate_image(pil_image)
        if not is_valid:
            logger.warning(f"Image validation failed: {message}")
            return self._create_low_confidence_result(), pil_image

        try:
            results = self._run_inference(pil_image)
            predictions = self._parse_results(results)
            # plot() yields a BGR array; reverse the channel axis for PIL (RGB).
            annotated = results[0].plot()
            annotated_pil = Image.fromarray(np.ascontiguousarray(annotated[..., ::-1]))
        except Exception as e:
            logger.error(f"Inference with visualization failed: {e}", exc_info=True)
            return self._create_low_confidence_result(), pil_image

        if not predictions:
            return self._create_low_confidence_result(), annotated_pil
        return predictions[0], annotated_pil

    def _make_prediction(
        self,
        class_index: int,
        confidence: float,
        bbox: Optional[List[float]] = None
    ) -> Optional[PredictionResult]:
        """Build a PredictionResult for a raw class id, or None if the id is out of range."""
        if getattr(self, "_is_placeholder", False):
            # Placeholder weights use unrelated class ids; fold them into range
            # so the pipeline can be exercised end to end.
            class_index %= len(self.CLASS_NAMES)

        if not 0 <= class_index < len(self.CLASS_NAMES):
            return None

        return PredictionResult(
            class_index=class_index,
            class_name=self.CLASS_NAMES[class_index],
            disease_key=self.CLASS_TO_KEY[class_index],
            confidence=confidence,
            crop_type=self.CLASS_TO_CROP[class_index],
            is_healthy=class_index in self.HEALTHY_INDICES,
            bbox=bbox
        )

    def _parse_results(self, results) -> List[PredictionResult]:
        """
        Convert Ultralytics results into PredictionResult objects.

        Classification output (``result.probs``) contributes its top-1 class.
        Detection output (``result.boxes``) contributes one prediction per box.

        Args:
            results: Iterable of Ultralytics results

        Returns:
            Predictions sorted by descending confidence (possibly empty)
        """
        predictions: List[PredictionResult] = []

        for result in results:
            probs = getattr(result, "probs", None)
            if probs is not None:
                top_conf = float(probs.top1conf)
                # The `conf` argument only filters detection boxes, not class
                # probabilities, so the threshold is enforced here for classifiers.
                if top_conf < self.confidence_threshold:
                    continue
                prediction = self._make_prediction(int(probs.top1), top_conf)
                if prediction is not None:
                    predictions.append(prediction)
                continue

            boxes = getattr(result, "boxes", None)
            if boxes is None:
                continue

            class_ids = boxes.cls.tolist()
            confidences = boxes.conf.tolist()
            coords = (
                boxes.xyxy.tolist() if boxes.xyxy is not None
                else [None] * len(class_ids)
            )
            for cls_value, conf_value, bbox in zip(class_ids, confidences, coords):
                prediction = self._make_prediction(int(cls_value), float(conf_value), bbox)
                if prediction is not None:
                    predictions.append(prediction)

        predictions.sort(key=lambda p: p.confidence, reverse=True)
        return predictions

    def _create_low_confidence_result(self) -> PredictionResult:
        """Create a result indicating low confidence / no detection."""
        return PredictionResult(
            class_index=-1,
            class_name="Unknown",
            disease_key="unknown",
            confidence=0.0,
            crop_type="unknown",
            is_healthy=False
        )

    def predict_batch(
        self,
        images: List[Union[str, Path, Image.Image, np.ndarray]]
    ) -> List[PredictionResult]:
        """
        Run ``predict`` on each image in turn.

        Args:
            images: List of input images

        Returns:
            One PredictionResult per image. Images that raise, e.g. a missing
            file, yield the "Unknown" result.
        """
        if not self._is_loaded:
            self.load_model()

        results = []
        for image in images:
            try:
                results.append(self.predict(image))
            except Exception as e:
                logger.error(f"Failed to process image: {e}", exc_info=True)
                results.append(self._create_low_confidence_result())
        return results

    def get_class_info(self, class_index: int) -> Dict:
        """
        Get information about a class by index.

        Args:
            class_index: Index of the class (0-5)

        Returns:
            Dictionary with class information ("Unknown" values if out of range)
        """
        if not 0 <= class_index < len(self.CLASS_NAMES):
            return {
                "class_index": class_index,
                "class_name": "Unknown",
                "disease_key": "unknown",
                "crop_type": "unknown",
                "is_healthy": False
            }

        return {
            "class_index": class_index,
            "class_name": self.CLASS_NAMES[class_index],
            "disease_key": self.CLASS_TO_KEY[class_index],
            "crop_type": self.CLASS_TO_CROP[class_index],
            "is_healthy": class_index in self.HEALTHY_INDICES
        }

    def get_model_info(self) -> Dict:
        """Get information about the model configuration and load state."""
        info = {
            "model_path": self.model_path,
            "is_loaded": self._is_loaded,
            "device": self.device,
            "confidence_threshold": self.confidence_threshold,
            "input_size": self.input_size,
            "num_classes": len(self.CLASS_NAMES),
            # Copy so callers cannot mutate the class-level constant.
            "classes": list(self.CLASS_NAMES)
        }

        if self._is_loaded:
            info["is_placeholder"] = getattr(self, "_is_placeholder", False)

        return info
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Lazily created, module-wide YOLOModel shared by the helper functions below.
_model_instance: Optional[YOLOModel] = None


def get_yolo_model() -> YOLOModel:
    """
    Return the shared YOLOModel, building it from ``config.yolo_config`` on first use.

    Constructing the wrapper does not load weights; that happens on the
    first prediction (or an explicit ``load_model`` call).

    Returns:
        YOLOModel instance
    """
    global _model_instance

    if _model_instance is not None:
        return _model_instance

    from config import yolo_config

    settings = dict(
        model_path=str(yolo_config.model_path),
        confidence_threshold=yolo_config.confidence_threshold,
        iou_threshold=yolo_config.iou_threshold,
        device=yolo_config.device,
        input_size=yolo_config.input_size,
    )
    _model_instance = YOLOModel(**settings)
    return _model_instance
|
|
|
|
|
|
|
|
def unload_yolo_model():
    """
    Unload the shared YOLO model, if one exists, and forget it.

    The next ``get_yolo_model`` call creates a fresh instance.
    """
    global _model_instance

    current = _model_instance
    if current is None:
        return

    current.unload_model()
    _model_instance = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def detect_disease(
    image: Union[str, Path, Image.Image, np.ndarray]
) -> PredictionResult:
    """
    Detect disease in one image using the shared YOLO model.

    Args:
        image: Input image (path, PIL Image, or numpy array)

    Returns:
        PredictionResult with disease information
    """
    shared_model = get_yolo_model()
    result = shared_model.predict(image)
    return result
|
|
|
|
|
|
|
|
def detect_disease_with_image(
    image: Union[str, Path, Image.Image, np.ndarray]
) -> Tuple[PredictionResult, Image.Image]:
    """
    Detect disease with the shared YOLO model and also return an annotated image.

    Args:
        image: Input image

    Returns:
        Tuple of (PredictionResult, annotated Image)
    """
    shared_model = get_yolo_model()
    prediction, annotated = shared_model.predict_with_visualization(image)
    return prediction, annotated
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import torch

    rule = "=" * 60

    def _section(title):
        """Print a numbered section heading preceded by a blank line."""
        print("\n" + title)

    print(rule)
    print("YOLOv11 Model Test (6-Class Disease Detection)")
    print(rule)

    # Report the torch build and MPS support on this machine.
    _section("1. Checking compute device...")
    print(f" PyTorch version: {torch.__version__}")
    print(f" MPS available: {torch.backends.mps.is_available()}")
    print(f" MPS built: {torch.backends.mps.is_built()}")

    # Build the wrapper with its default settings and load the weights.
    _section("2. Initializing YOLOv11 model...")
    model = YOLOModel()

    _section("3. Loading model...")
    model.load_model()

    _section("4. Model information:")
    for field_name, field_value in model.get_model_info().items():
        print(f" {field_name}: {field_value}")

    # No bundled sample image; show the user how to run a prediction instead.
    _section("5. Testing inference...")
    usage_lines = (
        " To test with an actual image, run:",
        " >>> result = model.predict('path/to/your/image.jpg')",
        " >>> print(result)",
    )
    for usage_line in usage_lines:
        print(usage_line)

    _section("6. Class mappings (6 classes - all diseases):")
    for class_id, class_name in enumerate(model.CLASS_NAMES):
        crop_name, disease_key = model.CLASS_TO_CROP[class_id], model.CLASS_TO_KEY[class_id]
        print(f" {class_id}: {class_name}")
        print(f" Crop: {crop_name}")
        print(f" Key: {disease_key}")

    print("\n" + rule)
    print("✅ YOLOv11 model test completed!")
    print(rule)
|
|
|