Spaces:
Running
Running
"""
Severity Classifier Module for CropDoctor-Semantic
===================================================

This module provides a CNN-based classifier to assess the severity
of plant diseases from segmented regions.

Severity Levels:
    0 - Healthy: No disease symptoms
    1 - Mild: <10% affected area, early stage
    2 - Moderate: 10-30% affected, established infection
    3 - Severe: >30% affected, critical intervention needed
"""
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.utils.data import Dataset, DataLoader | |
| import torchvision.transforms as transforms | |
| from torchvision import models | |
| import numpy as np | |
| from PIL import Image | |
| from pathlib import Path | |
| from typing import Tuple, Dict, List, Optional, Union | |
| from dataclasses import dataclass | |
| import logging | |
| logger = logging.getLogger(__name__) | |
@dataclass
class SeverityPrediction:
    """Container for severity classification results.

    The ``@dataclass`` decorator generates the keyword-argument constructor
    that the rest of this module relies on
    (``SeverityPrediction(severity_level=..., ...)``); a bare class that only
    carries annotations would reject those arguments with ``TypeError``.
    """

    severity_level: int  # Class index, 0-3
    severity_label: str  # "healthy", "mild", "moderate", "severe"
    confidence: float  # Probability of the predicted class, 0-1
    probabilities: Dict[str, float]  # Per-class probabilities keyed by label
    affected_area_percent: float  # From mask analysis
# Class index -> short label. The position of each label is the class index
# the classifier uses for it.
SEVERITY_LABELS = dict(enumerate(("healthy", "mild", "moderate", "severe")))

# Class index -> human-readable explanation of that severity level.
SEVERITY_DESCRIPTIONS = dict(enumerate((
    "No disease symptoms detected. Plant appears healthy.",
    "Early stage infection. Less than 10% of tissue affected. Preventive action recommended.",
    "Established infection. 10-30% of tissue affected. Treatment required.",
    "Severe infection. Over 30% of tissue affected. Urgent intervention needed.",
)))
class SeverityClassifierCNN(nn.Module):
    """
    Image classifier that maps a batch of images to severity-class logits.

    A torchvision backbone is instantiated and its final classification
    layer is replaced by a dropout + linear head with ``num_classes`` outputs.

    Supported ``backbone`` values:
        - "efficientnet_b0"
        - "resnet50"
        - "convnext_tiny"
    """

    def __init__(
        self,
        num_classes: int = 4,
        backbone: str = "efficientnet_b0",
        pretrained: bool = True,
        dropout: float = 0.3
    ):
        """
        Build the backbone and attach a fresh classification head.

        Args:
            num_classes: Number of output logits
            backbone: One of the supported backbone names
            pretrained: Load torchvision's default weights when True
            dropout: Dropout probability applied before the final linear layer

        Raises:
            ValueError: If ``backbone`` is not a supported name
        """
        super().__init__()
        self.num_classes = num_classes
        self.backbone_name = backbone

        # Reject unsupported names up front so the branches below only
        # deal with known architectures.
        if backbone not in ("efficientnet_b0", "resnet50", "convnext_tiny"):
            raise ValueError(f"Unknown backbone: {backbone}")

        if backbone == "resnet50":
            weights = models.ResNet50_Weights.DEFAULT if pretrained else None
            net = models.resnet50(weights=weights)
            width = net.fc.in_features
            net.fc = nn.Sequential(
                nn.Dropout(dropout),
                nn.Linear(width, num_classes)
            )
        elif backbone == "convnext_tiny":
            weights = models.ConvNeXt_Tiny_Weights.DEFAULT if pretrained else None
            net = models.convnext_tiny(weights=weights)
            width = net.classifier[2].in_features
            net.classifier = nn.Sequential(
                nn.Flatten(1),
                nn.LayerNorm(width),
                nn.Dropout(dropout),
                nn.Linear(width, num_classes)
            )
        else:
            # Only "efficientnet_b0" can reach this branch.
            weights = models.EfficientNet_B0_Weights.DEFAULT if pretrained else None
            net = models.efficientnet_b0(weights=weights)
            width = net.classifier[1].in_features
            net.classifier = nn.Sequential(
                nn.Dropout(dropout),
                nn.Linear(width, num_classes)
            )

        self.backbone = net

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw logits produced by the backbone for ``x``."""
        return self.backbone(x)

    def predict(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(predicted class indices, softmax probabilities)`` for ``x``."""
        probabilities = F.softmax(self.forward(x), dim=1)
        return torch.argmax(probabilities, dim=1), probabilities
class SeverityClassifier:
    """
    High-level interface for severity classification.

    Handles image preprocessing, lazy model loading, and prediction formatting.

    Example:
        >>> classifier = SeverityClassifier("models/severity_classifier/best.pt")
        >>> result = classifier.classify("diseased_leaf.jpg")
        >>> print(f"Severity: {result.severity_label} ({result.confidence:.2f})")
    """

    def __init__(
        self,
        checkpoint_path: Optional[str] = None,
        device: Optional[str] = None,
        image_size: int = 224
    ):
        """
        Initialize severity classifier.

        The model itself is not built here; it is created on the first call
        to ``load_model`` (which ``classify`` and ``classify_batch`` trigger).

        Args:
            checkpoint_path: Path to trained model checkpoint
            device: Device to use (auto-detected if None)
            image_size: Input image size for the model
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.image_size = image_size
        self.checkpoint_path = checkpoint_path

        # Populated lazily by load_model().
        self.model = None
        self._setup_transforms()

    def _setup_transforms(self):
        """Create the inference (``transform``) and training (``train_transform``) pipelines."""
        # ImageNet normalization, shared by both pipelines.
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )

        self.transform = transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
            normalize
        ])

        # Augmentation for training
        self.train_transform = transforms.Compose([
            transforms.RandomResizedCrop(self.image_size, scale=(0.8, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            transforms.ToTensor(),
            normalize
        ])

    def load_model(self, backbone: str = "efficientnet_b0"):
        """
        Load or initialize the model (no-op if a model is already loaded).

        When the checkpoint file exists, the backbone name stored in it under
        ``"backbone"`` takes precedence over the ``backbone`` argument, because
        the saved weights only fit the architecture they were trained with.

        Args:
            backbone: Backbone to build when the checkpoint does not name one
        """
        if self.model is not None:
            return

        checkpoint = None
        if self.checkpoint_path and Path(self.checkpoint_path).exists():
            logger.info(f"Loading checkpoint from {self.checkpoint_path}")
            # SECURITY: weights_only=False unpickles arbitrary Python objects.
            # Only load checkpoints from a trusted source.
            checkpoint = torch.load(
                self.checkpoint_path,
                map_location=self.device,
                weights_only=False
            )
            backbone = checkpoint.get("backbone", backbone)
        else:
            logger.warning("No checkpoint loaded, using pretrained backbone only")

        model = SeverityClassifierCNN(
            num_classes=len(SEVERITY_LABELS),
            backbone=backbone,
            # The checkpoint overwrites every weight, so fetching the
            # ImageNet weights first would be wasted work.
            pretrained=checkpoint is None
        )
        if checkpoint is not None:
            model.load_state_dict(checkpoint["model_state_dict"])
        model.to(self.device)
        model.eval()

        # Publish the model only after it loaded successfully; otherwise a
        # failed load would leave a half-initialised model that later calls
        # would silently reuse.
        self.model = model

    @staticmethod
    def _to_rgb_image(image):
        """Turn a path or numpy array into a PIL image and force RGB mode.

        Inputs of any other type are returned unchanged.
        """
        if isinstance(image, (str, Path)):
            # convert() returns a new in-memory image, so the file can be closed.
            with Image.open(image) as handle:
                return handle.convert("RGB")
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        if isinstance(image, Image.Image) and image.mode != "RGB":
            # Grayscale / RGBA inputs would not match the 3-channel Normalize.
            image = image.convert("RGB")
        return image

    @staticmethod
    def _affected_area_percent(mask: Optional[np.ndarray]) -> float:
        """Return the percentage of non-zero mask pixels (0.0 for no or empty mask)."""
        if mask is None:
            return 0.0
        mask = np.asarray(mask)
        if mask.size == 0:
            return 0.0
        # Counting non-zero pixels is correct for bool, 0/1 and 0/255 masks;
        # summing raw values would overstate 0/255 masks by a factor of 255.
        return float(np.count_nonzero(mask)) / mask.size * 100

    @staticmethod
    def _build_prediction(
        severity_level: int,
        probs_row: torch.Tensor,
        affected_percent: float
    ) -> SeverityPrediction:
        """Assemble a SeverityPrediction from one row of class probabilities."""
        scores = probs_row.tolist()
        return SeverityPrediction(
            severity_level=severity_level,
            severity_label=SEVERITY_LABELS[severity_level],
            confidence=scores[severity_level],
            probabilities={
                label: scores[idx] for idx, label in SEVERITY_LABELS.items()
            },
            affected_area_percent=affected_percent
        )

    def preprocess_image(
        self,
        image: Union[str, Path, Image.Image, np.ndarray]
    ) -> torch.Tensor:
        """
        Preprocess image for classification.

        Args:
            image: Input image (path, PIL Image, or numpy array)

        Returns:
            Output of the inference transform with a leading batch dimension
        """
        tensor = self.transform(self._to_rgb_image(image))
        return tensor.unsqueeze(0)  # Add batch dimension

    def classify(
        self,
        image: Union[str, Path, Image.Image, np.ndarray],
        mask: Optional[np.ndarray] = None
    ) -> SeverityPrediction:
        """
        Classify disease severity in an image.

        Args:
            image: Input image (path, PIL Image, or numpy array)
            mask: Optional binary mask of diseased region; only used to
                compute ``affected_area_percent``

        Returns:
            SeverityPrediction with severity level, confidence, and details
        """
        self.load_model()

        affected_percent = self._affected_area_percent(mask)

        input_tensor = self.preprocess_image(image).to(self.device)
        with torch.no_grad():
            pred, probs = self.model.predict(input_tensor)

        return self._build_prediction(pred.item(), probs[0], affected_percent)

    def classify_region(
        self,
        image: Union[str, Path, Image.Image, np.ndarray],
        mask: np.ndarray
    ) -> SeverityPrediction:
        """
        Classify severity of a specific masked region.

        Extracts the padded bounding box of the mask and classifies that crop.
        An all-zero mask short-circuits to a "healthy" prediction without
        running the model.

        Args:
            image: Full image
            mask: Binary mask of region to classify

        Returns:
            SeverityPrediction for the masked region
        """
        img_array = np.array(self._to_rgb_image(image))
        mask = np.asarray(mask)

        # Rows / columns that contain at least one mask pixel.
        row_idx = np.flatnonzero(np.any(mask, axis=1))
        col_idx = np.flatnonzero(np.any(mask, axis=0))

        if row_idx.size == 0 or col_idx.size == 0:
            # Empty mask, return healthy
            return SeverityPrediction(
                severity_level=0,
                severity_label="healthy",
                confidence=1.0,
                probabilities={"healthy": 1.0, "mild": 0.0, "moderate": 0.0, "severe": 0.0},
                affected_area_percent=0.0
            )

        pad = 10
        height, width = img_array.shape[:2]
        # Slice ends are exclusive, hence the +1: the last mask row/column
        # gets the same padding as the first one.
        y_min = max(0, int(row_idx[0]) - pad)
        y_max = min(height, int(row_idx[-1]) + pad + 1)
        x_min = max(0, int(col_idx[0]) - pad)
        x_max = min(width, int(col_idx[-1]) + pad + 1)

        cropped = img_array[y_min:y_max, x_min:x_max]
        cropped_mask = mask[y_min:y_max, x_min:x_max]
        return self.classify(cropped, mask=cropped_mask)

    def classify_batch(
        self,
        images: List[Union[str, Path, Image.Image, np.ndarray]],
        masks: Optional[List[np.ndarray]] = None,
        batch_size: int = 16
    ) -> List[SeverityPrediction]:
        """
        Classify multiple images in batches.

        Args:
            images: List of images to classify
            masks: Optional list of masks for each image
            batch_size: Batch size for inference

        Returns:
            List of SeverityPrediction for each image

        Raises:
            ValueError: If ``masks`` is non-empty and its length differs
                from ``images``
        """
        # len() instead of truthiness so an ndarray of masks is accepted too.
        has_masks = masks is not None and len(masks) > 0
        if has_masks and len(masks) != len(images):
            raise ValueError(
                f"Expected {len(images)} masks (one per image), got {len(masks)}"
            )

        self.load_model()
        results = []

        for start in range(0, len(images), batch_size):
            chunk = images[start:start + batch_size]
            batch_tensor = torch.cat(
                [self.preprocess_image(img) for img in chunk], dim=0
            ).to(self.device)

            with torch.no_grad():
                preds, probs = self.model.predict(batch_tensor)

            for offset, (level, prob) in enumerate(zip(preds.tolist(), probs)):
                mask = masks[start + offset] if has_masks else None
                results.append(
                    self._build_prediction(
                        level, prob, self._affected_area_percent(mask)
                    )
                )

        return results
class PlantDiseaseDataset(Dataset):
    """
    Dataset class for training severity classifier.

    Expected folder structure:
        data_root/
            healthy/
                image1.jpg
                image2.jpg
            mild/
                ...
            moderate/
                ...
            severe/
                ...

    Each item is ``(image, label_index)`` where ``label_index`` is the
    SEVERITY_LABELS key of the folder the image was found in.
    """

    # File suffixes (compared lower-cased) that are treated as images.
    IMAGE_SUFFIXES = (".jpg", ".jpeg", ".png")

    def __init__(
        self,
        data_root: str,
        transform: Optional[transforms.Compose] = None,
        split: str = "train"
    ):
        """
        Index all images below ``data_root``.

        Args:
            data_root: Directory containing one sub-folder per severity label
            transform: Optional transform applied to each loaded PIL image
            split: Name of the split; only used in the log message
        """
        self.data_root = Path(data_root)
        self.transform = transform
        self.split = split

        # Collect image paths and labels
        self.samples = []
        for label_idx, label_name in SEVERITY_LABELS.items():
            label_dir = self.data_root / label_name
            if not label_dir.is_dir():
                logger.warning(f"Class folder not found, skipping: {label_dir}")
                continue
            # sorted() makes the sample order independent of the filesystem;
            # the case-insensitive suffix check also picks up .JPG / .jpeg.
            self.samples.extend(
                (img_path, label_idx)
                for img_path in sorted(label_dir.iterdir())
                if img_path.is_file()
                and img_path.suffix.lower() in self.IMAGE_SUFFIXES
            )

        logger.info(f"Loaded {len(self.samples)} samples for {split}")

    def __len__(self) -> int:
        """Return the number of indexed images."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """Load image ``idx`` as RGB, apply the transform if set, and return it with its label."""
        img_path, label = self.samples[idx]
        # convert() yields an in-memory copy, so the file handle can be
        # released right away instead of waiting for garbage collection.
        with Image.open(img_path) as handle:
            image = handle.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, label
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader``: train if ``optimizer`` is given, else evaluate.

    Returns:
        Tuple of (mean loss per sample, accuracy in percent); both are 0.0
        when the loader yields no samples.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum = 0.0
    correct = 0
    total = 0
    # Gradients are only tracked while training.
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            batch = labels.size(0)
            # Weight by batch size so a short last batch does not skew the mean.
            loss_sum += loss.item() * batch
            correct += outputs.argmax(dim=1).eq(labels).sum().item()
            total += batch

    if total == 0:
        return 0.0, 0.0
    return loss_sum / total, 100. * correct / total


def _save_checkpoint(path, epoch, model, optimizer, val_acc, backbone):
    """Write model/optimizer state plus metadata to ``path``."""
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "val_acc": val_acc,
        "backbone": backbone
    }, path)


def train_classifier(
    train_data_root: str,
    val_data_root: str,
    output_dir: str,
    backbone: str = "efficientnet_b0",
    epochs: int = 50,
    batch_size: int = 32,
    learning_rate: float = 1e-4,
    device: str = "cuda"
):
    """
    Train the severity classifier.

    Writes ``best.pt`` (highest validation accuracy seen so far) and
    ``final.pt`` (state after the last epoch) into ``output_dir``.

    Args:
        train_data_root: Path to training data
        val_data_root: Path to validation data
        output_dir: Where to save checkpoints
        backbone: Model backbone to use
        epochs: Number of training epochs
        batch_size: Training batch size
        learning_rate: Initial learning rate
        device: Device to train on; falls back to CPU with a warning when
            CUDA is requested but unavailable

    Raises:
        ValueError: If the training or validation folder contains no images
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    on_cuda = str(device).startswith("cuda")
    if on_cuda and not torch.cuda.is_available():
        logger.warning("CUDA requested but not available, falling back to CPU")
        device = "cpu"
        on_cuda = False

    # Setup classifier for transforms
    classifier = SeverityClassifier()

    # Create datasets
    train_dataset = PlantDiseaseDataset(
        train_data_root,
        transform=classifier.train_transform,
        split="train"
    )
    val_dataset = PlantDiseaseDataset(
        val_data_root,
        transform=classifier.transform,
        split="val"
    )
    # Fail early with a clear message instead of an obscure sampler error
    # or a division by zero in the accuracy computation.
    if len(train_dataset) == 0:
        raise ValueError(f"No training images found under {train_data_root}")
    if len(val_dataset) == 0:
        raise ValueError(f"No validation images found under {val_data_root}")

    # Create dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=on_cuda  # pinned memory only helps host-to-GPU copies
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4
    )

    # Initialize model
    model = SeverityClassifierCNN(
        num_classes=4,
        backbone=backbone,
        pretrained=True
    ).to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

    best_val_acc = 0.0
    # Defined before the loop so the final save works even if no epoch runs.
    val_acc = 0.0

    for epoch in range(epochs):
        train_loss, train_acc = _run_epoch(
            model, train_loader, criterion, device, optimizer
        )
        val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)

        scheduler.step()

        logger.info(
            f"Epoch {epoch+1}/{epochs} - "
            f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% - "
            f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%"
        )

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            _save_checkpoint(
                output_dir / "best.pt", epoch, model, optimizer, val_acc, backbone
            )
            logger.info(f"Saved best model with val_acc: {val_acc:.2f}%")

    # Save final model
    _save_checkpoint(
        output_dir / "final.pt", epochs, model, optimizer, val_acc, backbone
    )
    logger.info(f"Training complete. Best val_acc: {best_val_acc:.2f}%")
if __name__ == "__main__":
    # Smoke test: classify a synthetic single-colour image and print the result.
    # Without a checkpoint the classification head is untrained, so the
    # printed severity is not meaningful.
    demo_image = Image.new("RGB", (224, 224), color=(139, 69, 19))  # Brown
    demo_result = SeverityClassifier().classify(demo_image)

    for line in (
        f"Severity: {demo_result.severity_label}",
        f"Confidence: {demo_result.confidence:.2f}",
        f"Probabilities: {demo_result.probabilities}",
        f"Description: {SEVERITY_DESCRIPTIONS[demo_result.severity_level]}",
    ):
        print(line)