""" Severity Classifier Module for CropDoctor-Semantic =================================================== This module provides a CNN-based classifier to assess the severity of plant diseases from segmented regions. Severity Levels: 0 - Healthy: No disease symptoms 1 - Mild: <10% affected area, early stage 2 - Moderate: 10-30% affected, established infection 3 - Severe: >30% affected, critical intervention needed """ import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader import torchvision.transforms as transforms from torchvision import models import numpy as np from PIL import Image from pathlib import Path from typing import Tuple, Dict, List, Optional, Union from dataclasses import dataclass import logging logger = logging.getLogger(__name__) @dataclass class SeverityPrediction: """Container for severity classification results.""" severity_level: int # 0-3 severity_label: str # "healthy", "mild", "moderate", "severe" confidence: float # 0-1 probabilities: Dict[str, float] # Per-class probabilities affected_area_percent: float # From mask analysis # Severity level mapping SEVERITY_LABELS = { 0: "healthy", 1: "mild", 2: "moderate", 3: "severe" } SEVERITY_DESCRIPTIONS = { 0: "No disease symptoms detected. Plant appears healthy.", 1: "Early stage infection. Less than 10% of tissue affected. Preventive action recommended.", 2: "Established infection. 10-30% of tissue affected. Treatment required.", 3: "Severe infection. Over 30% of tissue affected. Urgent intervention needed." } class SeverityClassifierCNN(nn.Module): """ CNN model for disease severity classification. Architecture options: - EfficientNet-B0 (lightweight, fast) - ResNet-50 (balanced) - ConvNeXt-Tiny (modern, accurate) """ def __init__( self, num_classes: int = 4, backbone: str = "efficientnet_b0", pretrained: bool = True, dropout: float = 0.3 ): super().__init__() self.num_classes = num_classes self.backbone_name = backbone # Load backbone if backbone == "efficientnet_b0": self.backbone = models.efficientnet_b0( weights=models.EfficientNet_B0_Weights.DEFAULT if pretrained else None ) in_features = self.backbone.classifier[1].in_features self.backbone.classifier = nn.Sequential( nn.Dropout(dropout), nn.Linear(in_features, num_classes) ) elif backbone == "resnet50": self.backbone = models.resnet50( weights=models.ResNet50_Weights.DEFAULT if pretrained else None ) in_features = self.backbone.fc.in_features self.backbone.fc = nn.Sequential( nn.Dropout(dropout), nn.Linear(in_features, num_classes) ) elif backbone == "convnext_tiny": self.backbone = models.convnext_tiny( weights=models.ConvNeXt_Tiny_Weights.DEFAULT if pretrained else None ) in_features = self.backbone.classifier[2].in_features self.backbone.classifier = nn.Sequential( nn.Flatten(1), nn.LayerNorm(in_features), nn.Dropout(dropout), nn.Linear(in_features, num_classes) ) else: raise ValueError(f"Unknown backbone: {backbone}") def forward(self, x: torch.Tensor) -> torch.Tensor: return self.backbone(x) def predict(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """Get predictions with probabilities.""" logits = self.forward(x) probs = F.softmax(logits, dim=1) preds = torch.argmax(probs, dim=1) return preds, probs class SeverityClassifier: """ High-level interface for severity classification. Handles image preprocessing, model loading, and prediction formatting. Example: >>> classifier = SeverityClassifier("models/severity_classifier/best.pt") >>> result = classifier.classify("diseased_leaf.jpg") >>> print(f"Severity: {result.severity_label} ({result.confidence:.2f})") """ def __init__( self, checkpoint_path: Optional[str] = None, device: Optional[str] = None, image_size: int = 224 ): """ Initialize severity classifier. Args: checkpoint_path: Path to trained model checkpoint device: Device to use (auto-detected if None) image_size: Input image size for the model """ # Set device if device is None: self.device = "cuda" if torch.cuda.is_available() else "cpu" else: self.device = device self.image_size = image_size self.checkpoint_path = checkpoint_path # Initialize model self.model = None self._setup_transforms() def _setup_transforms(self): """Setup image preprocessing transforms.""" # ImageNet normalization self.transform = transforms.Compose([ transforms.Resize((self.image_size, self.image_size)), transforms.ToTensor(), transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ) ]) # Augmentation for training self.train_transform = transforms.Compose([ transforms.RandomResizedCrop(self.image_size, scale=(0.8, 1.0)), transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip(), transforms.RandomRotation(15), transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2), transforms.ToTensor(), transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ) ]) def load_model(self, backbone: str = "efficientnet_b0"): """Load or initialize the model.""" if self.model is not None: return self.model = SeverityClassifierCNN( num_classes=4, backbone=backbone, pretrained=True ) if self.checkpoint_path and Path(self.checkpoint_path).exists(): logger.info(f"Loading checkpoint from {self.checkpoint_path}") checkpoint = torch.load(self.checkpoint_path, map_location=self.device, weights_only=False) self.model.load_state_dict(checkpoint["model_state_dict"]) else: logger.warning("No checkpoint loaded, using pretrained backbone only") self.model.to(self.device) self.model.eval() def preprocess_image( self, image: Union[str, Path, Image.Image, np.ndarray] ) -> torch.Tensor: """Preprocess image for classification.""" if isinstance(image, (str, Path)): image = Image.open(image).convert("RGB") elif isinstance(image, np.ndarray): image = Image.fromarray(image) tensor = self.transform(image) return tensor.unsqueeze(0) # Add batch dimension def classify( self, image: Union[str, Path, Image.Image, np.ndarray], mask: Optional[np.ndarray] = None ) -> SeverityPrediction: """ Classify disease severity in an image. Args: image: Input image (path, PIL Image, or numpy array) mask: Optional binary mask of diseased region Returns: SeverityPrediction with severity level, confidence, and details """ self.load_model() # Calculate affected area from mask affected_percent = 0.0 if mask is not None: affected_percent = (mask.sum() / mask.size) * 100 # Preprocess and predict input_tensor = self.preprocess_image(image).to(self.device) with torch.no_grad(): pred, probs = self.model.predict(input_tensor) severity_level = pred.item() confidence = probs[0, severity_level].item() # Format probabilities prob_dict = { SEVERITY_LABELS[i]: probs[0, i].item() for i in range(4) } return SeverityPrediction( severity_level=severity_level, severity_label=SEVERITY_LABELS[severity_level], confidence=confidence, probabilities=prob_dict, affected_area_percent=affected_percent ) def classify_region( self, image: Union[str, Path, Image.Image, np.ndarray], mask: np.ndarray ) -> SeverityPrediction: """ Classify severity of a specific masked region. Extracts the bounding box of the mask and classifies that region. Args: image: Full image mask: Binary mask of region to classify Returns: SeverityPrediction for the masked region """ # Load image if needed if isinstance(image, (str, Path)): image = Image.open(image).convert("RGB") elif isinstance(image, np.ndarray): image = Image.fromarray(image) img_array = np.array(image) # Get bounding box from mask rows = np.any(mask, axis=1) cols = np.any(mask, axis=0) if not rows.any() or not cols.any(): # Empty mask, return healthy return SeverityPrediction( severity_level=0, severity_label="healthy", confidence=1.0, probabilities={"healthy": 1.0, "mild": 0.0, "moderate": 0.0, "severe": 0.0}, affected_area_percent=0.0 ) y_min, y_max = np.where(rows)[0][[0, -1]] x_min, x_max = np.where(cols)[0][[0, -1]] # Add padding pad = 10 y_min = max(0, y_min - pad) y_max = min(img_array.shape[0], y_max + pad) x_min = max(0, x_min - pad) x_max = min(img_array.shape[1], x_max + pad) # Crop region cropped = img_array[y_min:y_max, x_min:x_max] cropped_mask = mask[y_min:y_max, x_min:x_max] return self.classify(cropped, mask=cropped_mask) def classify_batch( self, images: List[Union[str, Path, Image.Image, np.ndarray]], masks: Optional[List[np.ndarray]] = None, batch_size: int = 16 ) -> List[SeverityPrediction]: """ Classify multiple images in batches. Args: images: List of images to classify masks: Optional list of masks for each image batch_size: Batch size for inference Returns: List of SeverityPrediction for each image """ self.load_model() results = [] for i in range(0, len(images), batch_size): batch_images = images[i:i + batch_size] batch_masks = masks[i:i + batch_size] if masks else [None] * len(batch_images) # Preprocess batch tensors = [self.preprocess_image(img) for img in batch_images] batch_tensor = torch.cat(tensors, dim=0).to(self.device) # Predict with torch.no_grad(): preds, probs = self.model.predict(batch_tensor) # Format results for j, (pred, prob) in enumerate(zip(preds, probs)): mask = batch_masks[j] affected_percent = 0.0 if mask is not None: affected_percent = (mask.sum() / mask.size) * 100 severity_level = pred.item() results.append(SeverityPrediction( severity_level=severity_level, severity_label=SEVERITY_LABELS[severity_level], confidence=prob[severity_level].item(), probabilities={ SEVERITY_LABELS[k]: prob[k].item() for k in range(4) }, affected_area_percent=affected_percent )) return results class PlantDiseaseDataset(Dataset): """ Dataset class for training severity classifier. Expected folder structure: data_root/ healthy/ image1.jpg image2.jpg mild/ ... moderate/ ... severe/ ... """ def __init__( self, data_root: str, transform: Optional[transforms.Compose] = None, split: str = "train" ): self.data_root = Path(data_root) self.transform = transform self.split = split # Collect image paths and labels self.samples = [] for label_idx, label_name in SEVERITY_LABELS.items(): label_dir = self.data_root / label_name if label_dir.exists(): for img_path in label_dir.glob("*.jpg"): self.samples.append((img_path, label_idx)) for img_path in label_dir.glob("*.png"): self.samples.append((img_path, label_idx)) logger.info(f"Loaded {len(self.samples)} samples for {split}") def __len__(self) -> int: return len(self.samples) def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]: img_path, label = self.samples[idx] image = Image.open(img_path).convert("RGB") if self.transform: image = self.transform(image) return image, label def train_classifier( train_data_root: str, val_data_root: str, output_dir: str, backbone: str = "efficientnet_b0", epochs: int = 50, batch_size: int = 32, learning_rate: float = 1e-4, device: str = "cuda" ): """ Train the severity classifier. Args: train_data_root: Path to training data val_data_root: Path to validation data output_dir: Where to save checkpoints backbone: Model backbone to use epochs: Number of training epochs batch_size: Training batch size learning_rate: Initial learning rate device: Device to train on """ output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) # Setup classifier for transforms classifier = SeverityClassifier() # Create datasets train_dataset = PlantDiseaseDataset( train_data_root, transform=classifier.train_transform, split="train" ) val_dataset = PlantDiseaseDataset( val_data_root, transform=classifier.transform, split="val" ) # Create dataloaders train_loader = DataLoader( train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True ) val_loader = DataLoader( val_dataset, batch_size=batch_size, shuffle=False, num_workers=4 ) # Initialize model model = SeverityClassifierCNN( num_classes=4, backbone=backbone, pretrained=True ).to(device) # Loss and optimizer criterion = nn.CrossEntropyLoss() optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs) best_val_acc = 0.0 for epoch in range(epochs): # Training model.train() train_loss = 0.0 train_correct = 0 train_total = 0 for images, labels in train_loader: images, labels = images.to(device), labels.to(device) optimizer.zero_grad() outputs = model(images) loss = criterion(outputs, labels) loss.backward() optimizer.step() train_loss += loss.item() _, predicted = outputs.max(1) train_total += labels.size(0) train_correct += predicted.eq(labels).sum().item() # Validation model.eval() val_loss = 0.0 val_correct = 0 val_total = 0 with torch.no_grad(): for images, labels in val_loader: images, labels = images.to(device), labels.to(device) outputs = model(images) loss = criterion(outputs, labels) val_loss += loss.item() _, predicted = outputs.max(1) val_total += labels.size(0) val_correct += predicted.eq(labels).sum().item() train_acc = 100. * train_correct / train_total val_acc = 100. * val_correct / val_total scheduler.step() logger.info( f"Epoch {epoch+1}/{epochs} - " f"Train Loss: {train_loss/len(train_loader):.4f}, Train Acc: {train_acc:.2f}% - " f"Val Loss: {val_loss/len(val_loader):.4f}, Val Acc: {val_acc:.2f}%" ) # Save best model if val_acc > best_val_acc: best_val_acc = val_acc torch.save({ "epoch": epoch, "model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict(), "val_acc": val_acc, "backbone": backbone }, output_dir / "best.pt") logger.info(f"Saved best model with val_acc: {val_acc:.2f}%") # Save final model torch.save({ "epoch": epochs, "model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict(), "val_acc": val_acc, "backbone": backbone }, output_dir / "final.pt") logger.info(f"Training complete. Best val_acc: {best_val_acc:.2f}%") if __name__ == "__main__": # Quick test classifier = SeverityClassifier() # Create a test image test_image = Image.new("RGB", (224, 224), color=(139, 69, 19)) # Brown # Test classification (will use random weights without checkpoint) result = classifier.classify(test_image) print(f"Severity: {result.severity_label}") print(f"Confidence: {result.confidence:.2f}") print(f"Probabilities: {result.probabilities}") print(f"Description: {SEVERITY_DESCRIPTIONS[result.severity_level]}")