Spaces:
Running
Running
| """ | |
| Plant Disease Classifier | |
| ========================= | |
| Classifies plant leaf diseases using MobileNetV2. | |
| Model: linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification | |
| - 38 classes (26 diseases + 12 healthy plants) | |
| - 99.47% accuracy on PlantVillage dataset | |
| - Input: 224x224 RGB image | |
| Usage: | |
| from src.classifier import predict | |
| result = predict(pil_image) | |
| print(result["prediction"]) # "Tomato - Late Blight" | |
| """ | |
| import torch | |
| from PIL import Image | |
| from transformers import AutoImageProcessor, AutoModelForImageClassification | |
# ============================================================
# CONFIGURATION
# ============================================================
# HuggingFace Hub model id; passed to both AutoImageProcessor.from_pretrained()
# and AutoModelForImageClassification.from_pretrained() in _load_model().
MODEL_NAME = "linkanjarad/mobilenet_v2_1.0_224-plant-disease-identification"
# ============================================================
# MODULE STATE
# ============================================================
# Lazily-populated cache filled by _load_model() on first use.
# _model being None means "not loaded yet" (that is what _load_model checks).
_model = None  # classification model, moved to _device and put in eval mode
_processor = None  # image processor loaded from the same MODEL_NAME
_device = None  # "cuda" if torch sees a GPU, otherwise "cpu"
| # ============================================================ | |
| # PRIVATE FUNCTIONS | |
| # ============================================================ | |
def _load_model():
    """
    Load the model and processor from HuggingFace (lazy, cached).

    The first successful call loads MODEL_NAME, moves the model to the
    selected device and switches it to evaluation mode. Subsequent calls
    return the cached objects without reloading.

    Returns:
        tuple: (model, processor, device) where device is "cuda" or "cpu".
    """
    global _model, _processor, _device

    # Return cached objects if a previous call completed successfully.
    if _model is not None:
        return _model, _processor, _device

    print("🌱 Loading classification model...")

    # Prefer the GPU when torch can see one.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f" Device: {device}")

    # Build everything in locals and publish to the module cache only once
    # the model is fully ready. If from_pretrained() or .to(device) raises,
    # the cache stays empty and the next call retries, instead of handing
    # out a half-initialised model (e.g. still on CPU while _device says
    # "cuda").
    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
    model.to(device)
    model.eval()  # evaluation mode for inference (affects dropout/batch-norm)

    _model, _processor, _device = model, processor, device

    print(f"✅ Model loaded: {len(model.config.id2label)} classes")
    return _model, _processor, _device
def _parse_label(raw_label: str) -> tuple:
    """
    Parse a raw model label into a (plant, disease) pair.

    The label is split on the triple underscore "___". Underscores become
    spaces, parentheses are removed from the plant name, and the disease
    name is title-cased ("healthy" in any case becomes "Healthy").

    Args:
        raw_label: Model label, e.g. "Tomato___Late_blight".

    Returns:
        tuple: (plant, disease), e.g. ("Tomato", "Late Blight").
        A label without "___" yields (cleaned_label, "Unknown").
        A value that is not a str (e.g. None, bytes) is returned unchanged
        as (raw_label, "Unknown").
    """
    # NOTE(review): confirm the model's id2label really uses the "___"
    # separator; any label without it parses to disease "Unknown".
    try:
        parts = raw_label.split("___")
    except (AttributeError, TypeError):
        # Not a str. Keep the original fallback, but without a bare
        # `except:` that would also swallow KeyboardInterrupt/SystemExit.
        return (raw_label, "Unknown")

    plant = parts[0].replace("_", " ").replace("(", "").replace(")", "").strip()

    if len(parts) < 2:
        return (plant, "Unknown")

    disease = parts[1].replace("_", " ").strip()
    disease = "Healthy" if disease.lower() == "healthy" else disease.title()
    return (plant, disease)
| # ============================================================ | |
| # PUBLIC FUNCTION | |
| # ============================================================ | |
def predict(image: Image.Image, top_k: int = 3) -> dict:
    """
    Predict the disease shown in a plant leaf image.

    Args:
        image: PIL Image (PIL.Image.Image); converted to RGB before
            preprocessing.
        top_k: Number of ranked predictions to include under "top_k".
            Must be a non-negative integer; values larger than the number
            of classes return every class.

    Returns:
        dict with result:
        {
            "success": True,
            "prediction": "Tomato - Late Blight",
            "confidence": 95.23,        # percent, rounded to 2 decimals
            "is_healthy": False,
            "plant": "Tomato",
            "disease": "Late Blight",
            "raw_label": "Tomato___Late_blight",
            "top_k": [
                {"plant": "Tomato", "disease": "Late Blight",
                 "confidence": 95.23, "raw_label": "Tomato___Late_blight"},
                ...
            ]
        }
        On error (invalid input, model loading or inference failure):
        {
            "success": False,
            "error": "Error description"
        }
    """
    import operator  # local: only needed to validate top_k

    # Validate input before touching the (expensive) model.
    if image is None:
        return {
            "success": False,
            "error": "No image provided"
        }
    if not isinstance(image, Image.Image):
        return {
            "success": False,
            "error": f"Invalid image type: {type(image)}. Expected PIL.Image"
        }
    # Accept any int-like top_k (int, bool, numpy integer). Reject negatives
    # and non-integers with a clear message rather than a raw torch.topk
    # error.
    try:
        requested_k = operator.index(top_k)
    except TypeError:
        requested_k = -1  # non-integer: report through the same error below
    if requested_k < 0:
        return {
            "success": False,
            "error": f"Invalid top_k: {top_k!r}. Expected a non-negative integer"
        }

    try:
        # Load model (only first time)
        model, processor, device = _load_model()

        # Preprocess: force 3 channels, then move tensors to the model device
        inputs = processor(images=image.convert("RGB"), return_tensors="pt")
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

        # Inference without autograd bookkeeping
        with torch.no_grad():
            logits = model(**inputs).logits
        probs = torch.nn.functional.softmax(logits, dim=-1)

        # A single ranked pass serves both the top-1 prediction and the
        # top_k list. Ask for at least one entry so top-1 exists even when
        # top_k == 0.
        num_classes = probs.shape[-1]
        k = min(max(requested_k, 1), num_classes)
        ranked_probs, ranked_indices = torch.topk(probs, k)

        ranked = []
        for idx, prob in zip(ranked_indices[0].tolist(), ranked_probs[0].tolist()):
            label = model.config.id2label[idx]
            plant, disease = _parse_label(label)
            ranked.append({
                "plant": plant,
                "disease": disease,
                "confidence": round(prob * 100, 2),
                "raw_label": label
            })

        best = ranked[0]
        raw_label = best["raw_label"]

        # Return structured result
        return {
            "success": True,
            "prediction": f"{best['plant']} - {best['disease']}",
            "confidence": best["confidence"],
            "is_healthy": "healthy" in raw_label.lower(),
            "plant": best["plant"],
            "disease": best["disease"],
            "raw_label": raw_label,
            "top_k": ranked[:requested_k]
        }
    except Exception as e:
        # API boundary: report failures as data instead of raising.
        return {
            "success": False,
            "error": str(e)
        }
if __name__ == "__main__":
    # Smoke test. It runs only when this file is executed directly, so
    # `from src.classifier import predict` no longer loads the model or
    # prints anything at import time.
    print("\n" + "=" * 50)
    print("🧪 CLASSIFIER TEST")
    print("=" * 50)

    model, processor, device = _load_model()
    print(f"\nAvailable classes: {len(model.config.id2label)}")
    print(f"🖥️ Device: {device}")

    print("\nSample classes:")
    # Show the first five id -> label entries in parsed form.
    for idx, label in list(model.config.id2label.items())[:5]:
        plant, disease = _parse_label(label)
        print(f" {idx}: {plant} - {disease}")

    print("\n✅ Classifier ready")
    print("=" * 50)