Spaces:
Sleeping
Sleeping
| """ | |
| CropDoctor Pipeline Module | |
| ========================== | |
| This module integrates all components into a unified diagnostic pipeline: | |
| 1. SAM 3 Segmentation - Detect disease regions | |
| 2. Severity Classification - Assess severity level | |
| 3. Treatment Recommendation - Generate actionable advice | |
| Provides both single-image and batch processing capabilities. | |
| """ | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Union, Tuple | |
| from dataclasses import dataclass, asdict | |
| import json | |
| import csv | |
| from datetime import datetime | |
| import logging | |
| from .sam3_segmentation import SAM3Segmenter, create_segmenter, SegmentationResult | |
| from .severity_classifier import SeverityClassifier, SeverityPrediction, SEVERITY_LABELS | |
| from .treatment_recommender import TreatmentRecommender, TreatmentRecommendation | |
| from .leaf_segmenter import SAM2LeafSegmenter | |
| logger = logging.getLogger(__name__) | |
@dataclass
class DiagnosticResult:
    """Complete diagnostic result for an image.

    This must be a dataclass: the pipeline constructs it with keyword
    arguments and serialises it with ``dataclasses.asdict``.  The two
    ``segmentation_*`` fields are optional and default to ``None``; they are
    dropped before JSON export.
    """

    # Image info
    image_path: str  # str() of the input path, or "in_memory" for non-path inputs
    timestamp: str  # ISO-8601 string taken when the diagnosis started

    # Segmentation results
    num_regions_detected: int
    affected_area_percent: float
    detected_symptoms: List[str]

    # Severity assessment
    severity_level: int
    severity_label: str
    severity_confidence: float

    # Treatment recommendations
    disease_name: str
    disease_type: str
    disease_confidence: float
    organic_treatments: List[str]
    chemical_treatments: List[str]
    preventive_measures: List[str]
    treatment_timing: str
    urgency: str

    # Raw data for further analysis
    segmentation_masks: Optional[np.ndarray] = None
    segmentation_scores: Optional[np.ndarray] = None
class CropDoctorPipeline:
    """
    End-to-end pipeline for plant disease diagnosis.

    Integrates SAM 3 segmentation, severity classification, and
    LLM-based treatment recommendations into a single workflow.
    The segmenter, classifier, recommender and leaf segmenter are built on
    first access through the properties of the same name and then cached.

    Example:
        >>> pipeline = CropDoctorPipeline()
        >>> result = pipeline.diagnose("path/to/leaf.jpg")
        >>> print(f"Disease: {result.disease_name}")
        >>> print(f"Severity: {result.severity_label}")
        >>> print(f"Treatment: {result.organic_treatments[0]}")
    """

    # Prompt whose detections are treated as healthy tissue, not as a symptom.
    HEALTHY_PROMPT = "healthy green leaf"

    # Extensions scanned by batch_diagnose() when the caller passes none.
    DEFAULT_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

    # Columns written by _export_csv(); numpy arrays are deliberately left out.
    _CSV_FIELDS = (
        'image_path', 'timestamp', 'num_regions_detected',
        'affected_area_percent', 'detected_symptoms',
        'severity_level', 'severity_label', 'severity_confidence',
        'disease_name', 'disease_type', 'disease_confidence',
        'organic_treatments', 'urgency',
    )

    def __init__(
        self,
        sam3_checkpoint: str = "models/sam3/sam3.pt",
        sam3_config: str = "configs/sam3_config.yaml",
        classifier_checkpoint: Optional[str] = None,
        use_llm: bool = True,
        anthropic_api_key: Optional[str] = None,
        device: Optional[str] = None,
        use_mock_sam3: bool = False,
        use_rfdetr: bool = False,
        rfdetr_checkpoint: str = "models/rfdetr/checkpoint_best_total.pth",
        rfdetr_model_size: str = "medium",
        use_leaf_segmentation: bool = False,
        sam2_checkpoint: str = "models/sam2/sam2.1_hiera_small.pt"
    ):
        """
        Initialize the CropDoctor pipeline.

        Only configuration is stored here; no model is loaded until the
        corresponding property is first accessed.

        Args:
            sam3_checkpoint: Path to SAM 3 model checkpoint
            sam3_config: Path to SAM 3 configuration
            classifier_checkpoint: Path to severity classifier checkpoint
            use_llm: Whether to use Claude API for recommendations
            anthropic_api_key: Optional API key for Claude
            device: Device to use (auto-detected if None)
            use_mock_sam3: Use mock SAM 3 for testing without model
            use_rfdetr: Use RF-DETR for detection (recommended)
            rfdetr_checkpoint: Path to trained RF-DETR model
            rfdetr_model_size: RF-DETR model size (nano, small, medium, base)
            use_leaf_segmentation: Use SAM2 to segment leaf before detection
            sam2_checkpoint: Path to SAM2 checkpoint for leaf segmentation
        """
        # Set device
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        logger.info("Initializing CropDoctor Pipeline on %s", self.device)

        # Initialize components (lazy loading)
        self._segmenter = None
        self._classifier = None
        self._recommender = None
        self._leaf_segmenter = None

        # Store config
        self.sam3_checkpoint = sam3_checkpoint
        self.sam3_config = sam3_config
        self.classifier_checkpoint = classifier_checkpoint
        self.use_llm = use_llm
        self.anthropic_api_key = anthropic_api_key
        self.use_mock_sam3 = use_mock_sam3
        self.use_rfdetr = use_rfdetr
        self.rfdetr_checkpoint = rfdetr_checkpoint
        self.rfdetr_model_size = rfdetr_model_size
        self.use_leaf_segmentation = use_leaf_segmentation
        self.sam2_checkpoint = sam2_checkpoint

        # Default prompts for disease detection
        self.disease_prompts = [
            "diseased plant tissue",
            "leaf with brown spots",
            "powdery mildew",
            "yellowing leaves",
            "pest damage",
            "wilted tissue",
            self.HEALTHY_PROMPT,
        ]

    # The four accessors below must be properties: the rest of the class uses
    # them as attributes (e.g. ``self.segmenter.segment_with_concepts(...)``).

    @property
    def segmenter(self) -> "SAM3Segmenter":
        """Lazy load segmenter (SAM 3, Mock, or RF-DETR)."""
        if self._segmenter is None:
            self._segmenter = create_segmenter(
                self.sam3_checkpoint,
                self.sam3_config,
                use_mock=self.use_mock_sam3,
                use_rfdetr=self.use_rfdetr,
                rfdetr_checkpoint=self.rfdetr_checkpoint,
                rfdetr_model_size=self.rfdetr_model_size
            )
        return self._segmenter

    @property
    def classifier(self) -> "SeverityClassifier":
        """Lazy load severity classifier."""
        if self._classifier is None:
            self._classifier = SeverityClassifier(
                checkpoint_path=self.classifier_checkpoint,
                device=self.device
            )
        return self._classifier

    @property
    def recommender(self) -> "TreatmentRecommender":
        """Lazy load treatment recommender."""
        if self._recommender is None:
            self._recommender = TreatmentRecommender(
                api_key=self.anthropic_api_key,
                use_llm=self.use_llm
            )
        return self._recommender

    @property
    def leaf_segmenter(self) -> "SAM2LeafSegmenter":
        """Lazy load SAM2 leaf segmenter."""
        if self._leaf_segmenter is None:
            self._leaf_segmenter = SAM2LeafSegmenter(
                checkpoint_path=self.sam2_checkpoint,
                device=self.device
            )
        return self._leaf_segmenter

    @staticmethod
    def _load_image(image):
        """Turn a path, numpy array or PIL image into an RGB PIL image."""
        if isinstance(image, (str, Path)):
            # convert() returns a new, fully loaded image, so the file handle
            # can be released as soon as the block exits.
            with Image.open(image) as source:
                return source.convert("RGB")
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        # Give in-memory inputs the same RGB normalisation as path inputs.
        if isinstance(image, Image.Image) and image.mode != "RGB":
            image = image.convert("RGB")
        return image

    @classmethod
    def _collect_symptoms(cls, prompt_indices, prompts: List[str]) -> List[str]:
        """Return the non-healthy prompts that produced at least one detection."""
        return [
            prompts[idx]
            for idx in np.unique(prompt_indices)
            if idx < len(prompts) and prompts[idx] != cls.HEALTHY_PROMPT
        ]

    @classmethod
    def _combine_disease_masks(cls, masks, prompt_indices, prompts: List[str]):
        """OR together every mask from a non-healthy prompt; None if no masks."""
        if len(masks) == 0:
            return None
        combined = np.zeros(masks[0].shape, dtype=bool)
        for position, mask in enumerate(masks):
            prompt_idx = prompt_indices[position]
            if prompt_idx < len(prompts) and prompts[prompt_idx] != cls.HEALTHY_PROMPT:
                combined |= mask
        return combined

    @staticmethod
    def _healthy_recommendation() -> "TreatmentRecommendation":
        """Build the fixed recommendation used when no treatment is needed."""
        return TreatmentRecommendation(
            disease_name="No Disease Detected",
            disease_type="healthy",
            confidence=0.9,
            symptoms_matched=[],
            organic_treatments=["Continue regular care"],
            chemical_treatments=[],
            preventive_measures=[
                "Maintain good air circulation",
                "Water at soil level",
                "Monitor regularly for early symptoms"
            ],
            timing="Regular monitoring recommended",
            urgency="low",
            additional_notes="Plant appears healthy. Continue preventive care."
        )

    def diagnose(
        self,
        image: Union[str, Path, "Image.Image", np.ndarray],
        plant_species: Optional[str] = None,
        analysis_profile: str = "standard",
        custom_prompts: Optional[List[str]] = None,
        return_masks: bool = False
    ) -> "DiagnosticResult":
        """
        Perform complete diagnosis on an image.

        Args:
            image: Input image (path, PIL Image, or numpy array)
            plant_species: Optional plant species for context
            analysis_profile: SAM 3 analysis profile (accepted for
                compatibility; this method does not currently read it)
            custom_prompts: Optional custom prompts for segmentation
            return_masks: Whether to include segmentation masks in result

        Returns:
            DiagnosticResult with complete diagnosis
        """
        timestamp = datetime.now().isoformat()
        image_path = str(image) if isinstance(image, (str, Path)) else "in_memory"
        logger.info("Starting diagnosis for %s", image_path)

        pil_image = self._load_image(image)

        # Step 0 (optional): Leaf segmentation with SAM2
        if self.use_leaf_segmentation:
            logger.info("Step 0: Leaf segmentation with SAM2")
            # The returned leaf mask is not used further by this method.
            pil_image, _leaf_mask = self.leaf_segmenter.auto_segment_leaf(
                pil_image,
                return_mask=True
            )
            logger.info("Leaf isolated from background")

        # Step 1: Disease Detection (SAM3/RF-DETR)
        logger.info("Step 1: Disease detection")
        prompts = custom_prompts or self.disease_prompts
        seg_result = self.segmenter.segment_with_concepts(pil_image, prompts)

        # Step 1.5: Refine bounding boxes to proper masks using SAM2
        # This converts RF-DETR rectangular boxes to precise segmentation masks
        if self.use_rfdetr and len(seg_result.boxes) > 0:
            logger.info("Step 1.5: Refining detection boxes with SAM2")
            refined_masks = self.leaf_segmenter.refine_boxes_to_masks(
                pil_image,
                seg_result.boxes
            )
            # Replace rectangular masks with refined masks
            seg_result = SegmentationResult(
                masks=refined_masks,
                boxes=seg_result.boxes,
                scores=seg_result.scores,
                prompts=seg_result.prompts,
                prompt_indices=seg_result.prompt_indices
            )

        # Calculate affected area
        healthy_idx = (
            prompts.index(self.HEALTHY_PROMPT)
            if self.HEALTHY_PROMPT in prompts else None
        )
        area_stats = self.segmenter.calculate_affected_area(
            seg_result,
            healthy_prompt_idx=healthy_idx
        )
        affected_percent = area_stats["total_affected_percent"]

        # Get detected symptoms (prompts with detections)
        detected_symptoms = self._collect_symptoms(seg_result.prompt_indices, prompts)

        # Step 2: Severity Classification
        logger.info("Step 2: Severity classification")
        disease_mask = self._combine_disease_masks(
            seg_result.masks, seg_result.prompt_indices, prompts
        )
        severity_result = self.classifier.classify(pil_image, mask=disease_mask)

        # Override severity based on affected area if needed
        if affected_percent < 1:
            severity_result = SeverityPrediction(
                severity_level=0,
                severity_label="healthy",
                confidence=0.9,
                probabilities={"healthy": 0.9, "mild": 0.05, "moderate": 0.03, "severe": 0.02},
                affected_area_percent=affected_percent
            )
        elif affected_percent < 10 and severity_result.severity_level > 1:
            # Less than 10% affected: cap the classifier's verdict at "mild".
            severity_result.severity_level = 1
            severity_result.severity_label = "mild"

        # Step 3: Treatment Recommendations
        logger.info("Step 3: Generating treatment recommendations")
        if detected_symptoms and severity_result.severity_level > 0:
            treatment_result = self.recommender.get_recommendation(
                symptoms=detected_symptoms,
                severity=severity_result.severity_label,
                plant_species=plant_species,
                affected_area_percent=affected_percent
            )
        else:
            # Healthy plant - no treatment needed
            treatment_result = self._healthy_recommendation()

        # Compile final result
        result = DiagnosticResult(
            image_path=image_path,
            timestamp=timestamp,
            num_regions_detected=len(seg_result.masks),
            affected_area_percent=affected_percent,
            detected_symptoms=detected_symptoms,
            severity_level=severity_result.severity_level,
            severity_label=severity_result.severity_label,
            severity_confidence=severity_result.confidence,
            disease_name=treatment_result.disease_name,
            disease_type=treatment_result.disease_type,
            disease_confidence=treatment_result.confidence,
            organic_treatments=treatment_result.organic_treatments,
            chemical_treatments=treatment_result.chemical_treatments,
            preventive_measures=treatment_result.preventive_measures,
            treatment_timing=treatment_result.timing,
            urgency=treatment_result.urgency,
            segmentation_masks=seg_result.masks if return_masks else None,
            segmentation_scores=seg_result.scores if return_masks else None
        )
        logger.info("Diagnosis complete: %s (%s)", result.disease_name, result.severity_label)
        return result

    @staticmethod
    def _find_images(image_folder, file_extensions) -> List[Path]:
        """List files in *image_folder* whose name ends with one of the extensions.

        Matching is case-insensitive and every file is returned exactly once,
        in sorted order.  A missing folder yields an empty list.
        """
        folder = Path(image_folder)
        if not folder.is_dir():
            return []
        suffixes = tuple(ext.lower() for ext in file_extensions)
        return sorted(
            entry for entry in folder.iterdir()
            if entry.is_file() and entry.name.lower().endswith(suffixes)
        )

    def batch_diagnose(
        self,
        image_folder: Union[str, Path],
        output_dir: Optional[Union[str, Path]] = None,
        plant_species: Optional[str] = None,
        save_visualizations: bool = True,
        file_extensions: Optional[List[str]] = None
    ) -> List["DiagnosticResult"]:
        """
        Process multiple images from a folder.

        Images are processed in sorted filename order.  An image that fails
        is logged with its traceback and skipped; the batch continues.

        Args:
            image_folder: Path to folder containing images
            output_dir: Where to save results (optional)
            plant_species: Plant species for all images
            save_visualizations: Whether to save visualization images
            file_extensions: Image file extensions to process, matched
                case-insensitively (defaults to .jpg, .jpeg and .png)

        Returns:
            List of DiagnosticResult for each image
        """
        image_folder = Path(image_folder)
        if output_dir:
            output_dir = Path(output_dir)
            output_dir.mkdir(parents=True, exist_ok=True)

        # None instead of a list literal avoids a shared mutable default.
        if file_extensions is None:
            file_extensions = self.DEFAULT_IMAGE_EXTENSIONS

        # Find all images
        images = self._find_images(image_folder, file_extensions)
        logger.info("Found %d images to process", len(images))

        results = []
        for position, img_path in enumerate(images, start=1):
            logger.info("Processing %d/%d: %s", position, len(images), img_path.name)
            try:
                result = self.diagnose(
                    img_path,
                    plant_species=plant_species,
                    return_masks=save_visualizations
                )
                results.append(result)
                # Save visualization if requested
                if save_visualizations and output_dir:
                    self._save_visualization(
                        img_path,
                        result,
                        output_dir / f"{img_path.stem}_diagnosis.png"
                    )
            except Exception as e:
                # Deliberate best-effort: one bad image must not abort the batch.
                logger.exception("Error processing %s: %s", img_path, e)

        logger.info("Batch processing complete: %d/%d successful", len(results), len(images))
        return results

    def _save_visualization(
        self,
        image_path: Path,
        result: "DiagnosticResult",
        output_path: Path
    ):
        """Render the diagnostic figure for one image and save it to *output_path*."""
        # Imported here so plotting dependencies are only needed when used.
        import matplotlib.pyplot as plt
        from .visualization import create_diagnostic_visualization

        with Image.open(image_path) as image:
            fig = create_diagnostic_visualization(
                image,
                result.segmentation_masks,
                result.severity_label,
                result.disease_name,
                result.affected_area_percent
            )
            try:
                fig.savefig(output_path, dpi=150, bbox_inches='tight')
            finally:
                # Close even when saving fails, so figures do not pile up.
                plt.close(fig)

    def export_report(
        self,
        results: List["DiagnosticResult"],
        output_path: Union[str, Path],
        format: str = "csv"
    ):
        """
        Export results to file.

        Args:
            results: List of diagnostic results
            output_path: Output file path
            format: Output format ("csv", "json"), case-insensitive

        Raises:
            ValueError: If *format* is not a supported format.
        """
        output_path = Path(output_path)
        exporters = {"csv": self._export_csv, "json": self._export_json}
        key = format.lower() if isinstance(format, str) else format
        try:
            exporter = exporters[key]
        except (KeyError, TypeError):
            raise ValueError(f"Unknown format: {format}") from None
        exporter(results, output_path)
        logger.info("Report exported to %s", output_path)

    def _export_csv(self, results: List["DiagnosticResult"], output_path: Path):
        """Export to CSV (an empty file is written when there are no results)."""
        with open(output_path, 'w', newline='', encoding='utf-8') as f:
            if not results:
                return
            writer = csv.DictWriter(f, fieldnames=self._CSV_FIELDS)
            writer.writeheader()
            for result in results:
                writer.writerow({
                    'image_path': result.image_path,
                    'timestamp': result.timestamp,
                    'num_regions_detected': result.num_regions_detected,
                    'affected_area_percent': f"{result.affected_area_percent:.2f}",
                    'detected_symptoms': '; '.join(result.detected_symptoms),
                    'severity_level': result.severity_level,
                    'severity_label': result.severity_label,
                    'severity_confidence': f"{result.severity_confidence:.3f}",
                    'disease_name': result.disease_name,
                    'disease_type': result.disease_type,
                    'disease_confidence': f"{result.disease_confidence:.3f}",
                    # Only the first three treatments are exported.
                    'organic_treatments': '; '.join(result.organic_treatments[:3]),
                    'urgency': result.urgency
                })

    def _export_json(self, results: List["DiagnosticResult"], output_path: Path):
        """Export to JSON, leaving out the numpy mask/score fields."""
        # Local import: the module header only imports dataclass and asdict.
        from dataclasses import fields

        skipped = {'segmentation_masks', 'segmentation_scores'}
        # Read the fields directly instead of asdict(), which would deep-copy
        # the mask arrays only for them to be thrown away.
        data = [
            {
                spec.name: getattr(result, spec.name)
                for spec in fields(result)
                if spec.name not in skipped
            }
            for result in results
        ]
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, default=str)

    def generate_summary_report(
        self,
        results: List["DiagnosticResult"]
    ) -> str:
        """
        Generate a summary report for batch results.

        Args:
            results: List of diagnostic results

        Returns:
            Formatted summary report string
        """
        if not results:
            return "No results to summarize."

        # Calculate statistics
        total = len(results)
        level_counts = {}
        disease_counts = {}
        for r in results:
            level_counts[r.severity_level] = level_counts.get(r.severity_level, 0) + 1
            disease_counts[r.disease_name] = disease_counts.get(r.disease_name, 0) + 1
        healthy = level_counts.get(0, 0)
        mild = level_counts.get(1, 0)
        moderate = level_counts.get(2, 0)
        severe = level_counts.get(3, 0)

        # Average affected area
        avg_affected = np.mean([r.affected_area_percent for r in results])

        def share(count):
            return f"{count / total * 100:.1f}%"

        width = 62
        lines = [
            "",
            "╔" + "═" * width + "╗",
            "║" + "BATCH DIAGNOSIS SUMMARY".center(width) + "║",
            "╚" + "═" * width + "╝",
            "",
            "📊 OVERALL STATISTICS",
            f"   Total Images Analyzed: {total}",
            "",
            "   Severity Distribution:",
            f"   ├── 🟢 Healthy: {healthy} ({share(healthy)})",
            f"   ├── 🟡 Mild: {mild} ({share(mild)})",
            f"   ├── 🟠 Moderate: {moderate} ({share(moderate)})",
            f"   └── 🔴 Severe: {severe} ({share(severe)})",
            "",
            f"   Average Affected Area: {avg_affected:.1f}%",
            "",
            "🦠 DISEASE FREQUENCY",
        ]
        # Most frequent disease first; ties keep first-seen order.
        for disease, count in sorted(disease_counts.items(), key=lambda item: -item[1]):
            lines.append(f"   • {disease}: {count} ({share(count)})")

        # Urgent cases
        urgent = [r for r in results if r.urgency in ('high', 'critical')]
        if urgent:
            lines.append("")
            lines.append("⚠️ URGENT ATTENTION REQUIRED")
            lines.append(f"   {len(urgent)} images require immediate attention:")
            for r in urgent[:5]:  # Show top 5
                lines.append(f"   • {Path(r.image_path).name}: {r.disease_name} ({r.urgency})")

        lines.append("")
        lines.append("═" * (width + 1))
        return "\n".join(lines) + "\n"
def quick_diagnose(
    image_path: str,
    use_mock: bool = True
) -> DiagnosticResult:
    """
    Diagnose a single image with a throwaway pipeline.

    A new CropDoctorPipeline is created on every call with LLM-based
    recommendations switched off.

    Args:
        image_path: Path to the image to diagnose
        use_mock: Forwarded as ``use_mock_sam3`` (mock SAM 3, for testing
            without the real model)

    Returns:
        DiagnosticResult for the image
    """
    return CropDoctorPipeline(use_llm=False, use_mock_sam3=use_mock).diagnose(image_path)
if __name__ == "__main__":
    # Smoke test: run the pipeline on a synthetic image with mock SAM 3 and
    # the LLM disabled.
    import tempfile

    print("Testing CropDoctor Pipeline...")

    # Create test image (a flat brown rectangle)
    test_img = Image.new("RGB", (640, 480), color=(139, 69, 19))

    # A private temporary directory replaces the hard-coded /tmp path, which
    # is not portable and is a predictable shared location.
    with tempfile.TemporaryDirectory() as tmp_dir:
        test_path = Path(tmp_dir) / "test_leaf.jpg"
        test_img.save(test_path)

        # Run pipeline
        pipeline = CropDoctorPipeline(use_mock_sam3=True, use_llm=False)
        result = pipeline.diagnose(test_path)

    print("\n📋 Diagnosis Results:")
    print(f"   Disease: {result.disease_name}")
    print(f"   Type: {result.disease_type}")
    print(f"   Severity: {result.severity_label} (Level {result.severity_level})")
    print(f"   Affected Area: {result.affected_area_percent:.1f}%")
    print(f"   Urgency: {result.urgency}")
    print("\n🌿 Recommended Treatments:")
    for t in result.organic_treatments[:3]:
        print(f"   • {t}")