"""
Plant Disease Classifier
========================

Classifies plant leaf diseases using MobileNetV2.

Model: linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification
    - 38 classes (26 diseases + 12 healthy plants)
    - 99.47% accuracy on PlantVillage dataset
    - Input: 224x224 RGB image

Usage:
    from src.classifier import predict

    result = predict(pil_image)
    print(result["prediction"])  # "Tomato - Late Blight"

Importing this module does not load the model; it is loaded lazily on the
first call to ``predict``. Running the file directly as a script loads the
model and prints a short self-test.
"""

import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

# ============================================================
# CONFIGURATION
# ============================================================

MODEL_NAME = "linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification"

# ============================================================
# MODULE STATE
# ============================================================

# Lazily-populated cache, filled by _load_model(). All three are assigned
# together, only after the model has been fully set up.
_model = None
_processor = None
_device = None


# ============================================================
# PRIVATE FUNCTIONS
# ============================================================

def _load_model():
    """
    Load model and processor from HuggingFace.

    Executes only ONCE (lazy loading). Subsequent calls return cached objects.
    The cache is published only after the model has been moved to its device
    and switched to eval mode. A failure part-way through therefore leaves
    the cache empty, and the next call retries rather than reusing a
    half-initialized model.

    Returns:
        tuple: (model, processor, device)
    """
    global _model, _processor, _device

    # Return cached if already loaded
    if _model is not None:
        return _model, _processor, _device

    print("🌱 Loading classification model...")

    # Determine device (GPU or CPU)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f" Device: {device}")

    # Load processor (prepares images for model)
    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)

    # Load model
    model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
    model.to(device)
    model.eval()  # Set to evaluation mode

    # Publish to the module cache only once every step above has succeeded.
    _model, _processor, _device = model, processor, device

    print(f"✅ Model loaded: {len(_model.config.id2label)} classes")

    return _model, _processor, _device


def _parse_label(raw_label: str) -> tuple:
    """
    Parse raw model label into (plant, disease).

    Args:
        raw_label: Model label, e.g. "Tomato___Late_blight"

    Returns:
        tuple: (plant, disease), e.g. ("Tomato", "Late Blight").
        - Underscores become spaces, and parentheses are removed from the
          plant name.
        - The disease is title-cased, except "healthy", which becomes
          "Healthy".
        - A label without a "___" separator yields disease "Unknown".
        - A non-string label is returned unchanged with disease "Unknown".
    """
    try:
        # Split by triple underscore: "<plant>___<disease>"
        parts = raw_label.split("___")
    except (AttributeError, TypeError):
        # Not a str (e.g. None or bytes): fall back instead of crashing.
        return (raw_label, "Unknown")

    plant = parts[0].replace("_", " ").replace("(", "").replace(")", "").strip()

    if len(parts) > 1:
        disease = parts[1].replace("_", " ").strip()
        # Capitalize properly
        disease = "Healthy" if disease.lower() == "healthy" else disease.title()
    else:
        disease = "Unknown"

    return (plant, disease)


# ============================================================
# PUBLIC FUNCTION
# ============================================================

def predict(image: Image.Image, top_k: int = 3) -> dict:
    """
    Predict disease in a plant leaf image.

    Args:
        image: PIL Image (PIL.Image.Image). It is converted to RGB before
            preprocessing.
        top_k: Number of alternative predictions to return. The value is
            clamped to the range [0, number of classes].

    Returns:
        dict with result:
        {
            "success": True,
            "prediction": "Tomato - Late Blight",
            "confidence": 95.23,
            "is_healthy": False,
            "plant": "Tomato",
            "disease": "Late Blight",
            "raw_label": "Tomato___Late_blight",
            "top_k": [
                {"plant": "Tomato", "disease": "Late Blight",
                 "confidence": 95.23, "raw_label": "Tomato___Late_blight"},
                ...
            ]
        }

        On error (invalid input, or any exception raised while loading the
        model or running inference):
        {
            "success": False,
            "error": "Error description"
        }
    """
    # Validate input
    if image is None:
        return {
            "success": False,
            "error": "No image provided"
        }

    if not isinstance(image, Image.Image):
        return {
            "success": False,
            "error": f"Invalid image type: {type(image)}. Expected PIL.Image"
        }

    try:
        # Load model (only first time)
        model, processor, device = _load_model()

        # Preprocess image
        image = image.convert("RGB")
        inputs = processor(images=image, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Inference
        with torch.no_grad():
            outputs = model(**inputs)

        # Process results: softmax over the class dimension
        logits = outputs.logits
        probs = torch.nn.functional.softmax(logits, dim=-1)

        # Get top prediction (single image, so .item() is valid)
        top_prob, top_idx = torch.max(probs, dim=-1)
        raw_label = model.config.id2label[top_idx.item()]
        confidence = round(top_prob.item() * 100, 2)

        # Parse label
        plant, disease = _parse_label(raw_label)
        is_healthy = "healthy" in raw_label.lower()

        # Get top-k predictions; clamp k so torch.topk never sees k < 0
        # or k > number of classes.
        k = max(0, min(top_k, probs.shape[-1]))
        top_k_probs, top_k_indices = torch.topk(probs, k)

        top_k_results = []
        for idx, prob in zip(top_k_indices[0], top_k_probs[0]):
            label = model.config.id2label[idx.item()]
            p, d = _parse_label(label)
            top_k_results.append({
                "plant": p,
                "disease": d,
                "confidence": round(prob.item() * 100, 2),
                "raw_label": label
            })

        # Return structured result
        return {
            "success": True,
            "prediction": f"{plant} - {disease}",
            "confidence": confidence,
            "is_healthy": is_healthy,
            "plant": plant,
            "disease": disease,
            "raw_label": raw_label,
            "top_k": top_k_results
        }

    except Exception as e:
        # API boundary: report failures as a structured result, not an exception.
        return {
            "success": False,
            "error": str(e)
        }


# ============================================================
# SELF-TEST (only when run as a script, never on import)
# ============================================================

if __name__ == "__main__":
    print("\n" + "=" * 50)
    print("🧪 CLASSIFIER TEST")
    print("=" * 50)

    model, processor, device = _load_model()

    print(f"\n📊 Available classes: {len(model.config.id2label)}")
    print(f"🖥️ Device: {device}")

    print("\n📋 Sample classes:")
    for idx, label in list(model.config.id2label.items())[:5]:
        plant, disease = _parse_label(label)
        print(f" {idx}: {plant} - {disease}")

    print("\n✅ Classifier ready")
    print("=" * 50)