"""
CropDoctor Pipeline Module
==========================

This module integrates all components into a unified diagnostic pipeline:

1. SAM 3 Segmentation - Detect disease regions
2. Severity Classification - Assess severity level
3. Treatment Recommendation - Generate actionable advice

Provides both single-image and batch processing capabilities.
"""

import torch
import numpy as np
from PIL import Image
from pathlib import Path
from typing import Dict, List, Optional, Union, Tuple
from dataclasses import dataclass, asdict
import json
import csv
from datetime import datetime
import logging

from .sam3_segmentation import SAM3Segmenter, create_segmenter, SegmentationResult
from .severity_classifier import SeverityClassifier, SeverityPrediction, SEVERITY_LABELS
from .treatment_recommender import TreatmentRecommender, TreatmentRecommendation
from .leaf_segmenter import SAM2LeafSegmenter

logger = logging.getLogger(__name__)


@dataclass
class DiagnosticResult:
    """Complete diagnostic result for an image."""

    # Image info
    image_path: str
    timestamp: str

    # Segmentation results
    num_regions_detected: int
    affected_area_percent: float
    detected_symptoms: List[str]

    # Severity assessment
    severity_level: int
    severity_label: str
    severity_confidence: float

    # Treatment recommendations
    disease_name: str
    disease_type: str
    disease_confidence: float
    organic_treatments: List[str]
    chemical_treatments: List[str]
    preventive_measures: List[str]
    treatment_timing: str
    urgency: str

    # Raw data for further analysis (only populated when return_masks=True)
    segmentation_masks: Optional[np.ndarray] = None
    segmentation_scores: Optional[np.ndarray] = None


class CropDoctorPipeline:
    """
    End-to-end pipeline for plant disease diagnosis.

    Integrates SAM 3 segmentation, severity classification, and LLM-based
    treatment recommendations into a single workflow.

    Example:
        >>> pipeline = CropDoctorPipeline()
        >>> result = pipeline.diagnose("path/to/leaf.jpg")
        >>> print(f"Disease: {result.disease_name}")
        >>> print(f"Severity: {result.severity_label}")
        >>> print(f"Treatment: {result.organic_treatments[0]}")
    """

    def __init__(
        self,
        sam3_checkpoint: str = "models/sam3/sam3.pt",
        sam3_config: str = "configs/sam3_config.yaml",
        classifier_checkpoint: Optional[str] = None,
        use_llm: bool = True,
        anthropic_api_key: Optional[str] = None,
        device: Optional[str] = None,
        use_mock_sam3: bool = False,
        use_rfdetr: bool = False,
        rfdetr_checkpoint: str = "models/rfdetr/checkpoint_best_total.pth",
        rfdetr_model_size: str = "medium",
        use_leaf_segmentation: bool = False,
        sam2_checkpoint: str = "models/sam2/sam2.1_hiera_small.pt"
    ):
        """
        Initialize the CropDoctor pipeline.

        Args:
            sam3_checkpoint: Path to SAM 3 model checkpoint
            sam3_config: Path to SAM 3 configuration
            classifier_checkpoint: Path to severity classifier checkpoint
            use_llm: Whether to use Claude API for recommendations
            anthropic_api_key: Optional API key for Claude
            device: Device to use (auto-detected if None)
            use_mock_sam3: Use mock SAM 3 for testing without model
            use_rfdetr: Use RF-DETR for detection (recommended)
            rfdetr_checkpoint: Path to trained RF-DETR model
            rfdetr_model_size: RF-DETR model size (nano, small, medium, base)
            use_leaf_segmentation: Use SAM2 to segment leaf before detection
            sam2_checkpoint: Path to SAM2 checkpoint for leaf segmentation
        """
        # Set device
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        logger.info(f"Initializing CropDoctor Pipeline on {self.device}")

        # Initialize components (lazy loading — heavy models are only built
        # on first access via the properties below)
        self._segmenter = None
        self._classifier = None
        self._recommender = None
        self._leaf_segmenter = None

        # Store config
        self.sam3_checkpoint = sam3_checkpoint
        self.sam3_config = sam3_config
        self.classifier_checkpoint = classifier_checkpoint
        self.use_llm = use_llm
        self.anthropic_api_key = anthropic_api_key
        self.use_mock_sam3 = use_mock_sam3
        self.use_rfdetr = use_rfdetr
        self.rfdetr_checkpoint = rfdetr_checkpoint
        self.rfdetr_model_size = rfdetr_model_size
        self.use_leaf_segmentation = use_leaf_segmentation
        self.sam2_checkpoint = sam2_checkpoint

        # Default prompts for disease detection
        self.disease_prompts = [
            "diseased plant tissue",
            "leaf with brown spots",
            "powdery mildew",
            "yellowing leaves",
            "pest damage",
            "wilted tissue",
            "healthy green leaf"
        ]

    @property
    def segmenter(self) -> SAM3Segmenter:
        """Lazy load segmenter (SAM 3, Mock, or RF-DETR)."""
        if self._segmenter is None:
            self._segmenter = create_segmenter(
                self.sam3_checkpoint,
                self.sam3_config,
                use_mock=self.use_mock_sam3,
                use_rfdetr=self.use_rfdetr,
                rfdetr_checkpoint=self.rfdetr_checkpoint,
                rfdetr_model_size=self.rfdetr_model_size
            )
        return self._segmenter

    @property
    def classifier(self) -> SeverityClassifier:
        """Lazy load severity classifier."""
        if self._classifier is None:
            self._classifier = SeverityClassifier(
                checkpoint_path=self.classifier_checkpoint,
                device=self.device
            )
        return self._classifier

    @property
    def recommender(self) -> TreatmentRecommender:
        """Lazy load treatment recommender."""
        if self._recommender is None:
            self._recommender = TreatmentRecommender(
                api_key=self.anthropic_api_key,
                use_llm=self.use_llm
            )
        return self._recommender

    @property
    def leaf_segmenter(self) -> SAM2LeafSegmenter:
        """Lazy load SAM2 leaf segmenter."""
        if self._leaf_segmenter is None:
            self._leaf_segmenter = SAM2LeafSegmenter(
                checkpoint_path=self.sam2_checkpoint,
                device=self.device
            )
        return self._leaf_segmenter

    def diagnose(
        self,
        image: Union[str, Path, Image.Image, np.ndarray],
        plant_species: Optional[str] = None,
        analysis_profile: str = "standard",
        custom_prompts: Optional[List[str]] = None,
        return_masks: bool = False
    ) -> DiagnosticResult:
        """
        Perform complete diagnosis on an image.

        Args:
            image: Input image (path, PIL Image, or numpy array)
            plant_species: Optional plant species for context
            analysis_profile: SAM 3 analysis profile
            custom_prompts: Optional custom prompts for segmentation
            return_masks: Whether to include segmentation masks in result

        Returns:
            DiagnosticResult with complete diagnosis
        """
        timestamp = datetime.now().isoformat()
        image_path = str(image) if isinstance(image, (str, Path)) else "in_memory"

        logger.info(f"Starting diagnosis for {image_path}")

        # Load image if needed
        if isinstance(image, (str, Path)):
            pil_image = Image.open(image).convert("RGB")
        elif isinstance(image, np.ndarray):
            pil_image = Image.fromarray(image)
        else:
            pil_image = image

        # Store original image for visualization
        original_image = pil_image
        leaf_mask = None

        # Step 0 (optional): Leaf segmentation with SAM2
        if self.use_leaf_segmentation:
            logger.info("Step 0: Leaf segmentation with SAM2")
            pil_image, leaf_mask = self.leaf_segmenter.auto_segment_leaf(
                pil_image, return_mask=True
            )
            logger.info("Leaf isolated from background")

        # Step 1: Disease Detection (SAM3/RF-DETR)
        logger.info("Step 1: Disease detection")
        prompts = custom_prompts or self.disease_prompts
        seg_result = self.segmenter.segment_with_concepts(pil_image, prompts)

        # Step 1.5: Refine bounding boxes to proper masks using SAM2
        # This converts RF-DETR rectangular boxes to precise segmentation masks
        if self.use_rfdetr and len(seg_result.boxes) > 0:
            logger.info("Step 1.5: Refining detection boxes with SAM2")
            refined_masks = self.leaf_segmenter.refine_boxes_to_masks(
                pil_image, seg_result.boxes
            )
            # Replace rectangular masks with refined masks
            seg_result = SegmentationResult(
                masks=refined_masks,
                boxes=seg_result.boxes,
                scores=seg_result.scores,
                prompts=seg_result.prompts,
                prompt_indices=seg_result.prompt_indices
            )

        # Calculate affected area; the "healthy green leaf" prompt (if present)
        # is excluded from the affected-area total
        area_stats = self.segmenter.calculate_affected_area(
            seg_result,
            healthy_prompt_idx=prompts.index("healthy green leaf")
            if "healthy green leaf" in prompts else None
        )

        # Get detected symptoms (prompts with detections), skipping the
        # healthy-leaf prompt
        detected_symptoms = []
        for prompt_idx in np.unique(seg_result.prompt_indices):
            if prompt_idx < len(prompts):
                prompt = prompts[prompt_idx]
                if prompt != "healthy green leaf":
                    detected_symptoms.append(prompt)

        # Step 2: Severity Classification
        logger.info("Step 2: Severity classification")

        # Create combined disease mask for classification
        if len(seg_result.masks) > 0:
            # Combine all disease masks (excluding healthy)
            disease_mask = np.zeros(seg_result.masks[0].shape, dtype=bool)
            for i, mask in enumerate(seg_result.masks):
                prompt_idx = seg_result.prompt_indices[i]
                if prompt_idx < len(prompts) and prompts[prompt_idx] != "healthy green leaf":
                    disease_mask |= mask
        else:
            disease_mask = None

        severity_result = self.classifier.classify(pil_image, mask=disease_mask)

        # Override severity based on affected area if needed: <1% affected is
        # treated as healthy; <10% caps the level at "mild"
        if area_stats["total_affected_percent"] < 1:
            severity_result = SeverityPrediction(
                severity_level=0,
                severity_label="healthy",
                confidence=0.9,
                probabilities={"healthy": 0.9, "mild": 0.05, "moderate": 0.03, "severe": 0.02},
                affected_area_percent=area_stats["total_affected_percent"]
            )
        elif area_stats["total_affected_percent"] < 10 and severity_result.severity_level > 1:
            severity_result.severity_level = 1
            severity_result.severity_label = "mild"

        # Step 3: Treatment Recommendations
        logger.info("Step 3: Generating treatment recommendations")

        if detected_symptoms and severity_result.severity_level > 0:
            treatment_result = self.recommender.get_recommendation(
                symptoms=detected_symptoms,
                severity=severity_result.severity_label,
                plant_species=plant_species,
                affected_area_percent=area_stats["total_affected_percent"]
            )
        else:
            # Healthy plant - no treatment needed
            treatment_result = TreatmentRecommendation(
                disease_name="No Disease Detected",
                disease_type="healthy",
                confidence=0.9,
                symptoms_matched=[],
                organic_treatments=["Continue regular care"],
                chemical_treatments=[],
                preventive_measures=[
                    "Maintain good air circulation",
                    "Water at soil level",
                    "Monitor regularly for early symptoms"
                ],
                timing="Regular monitoring recommended",
                urgency="low",
                additional_notes="Plant appears healthy. Continue preventive care."
            )

        # Compile final result
        result = DiagnosticResult(
            image_path=image_path,
            timestamp=timestamp,
            num_regions_detected=len(seg_result.masks),
            affected_area_percent=area_stats["total_affected_percent"],
            detected_symptoms=detected_symptoms,
            severity_level=severity_result.severity_level,
            severity_label=severity_result.severity_label,
            severity_confidence=severity_result.confidence,
            disease_name=treatment_result.disease_name,
            disease_type=treatment_result.disease_type,
            disease_confidence=treatment_result.confidence,
            organic_treatments=treatment_result.organic_treatments,
            chemical_treatments=treatment_result.chemical_treatments,
            preventive_measures=treatment_result.preventive_measures,
            treatment_timing=treatment_result.timing,
            urgency=treatment_result.urgency,
            segmentation_masks=seg_result.masks if return_masks else None,
            segmentation_scores=seg_result.scores if return_masks else None
        )

        logger.info(f"Diagnosis complete: {result.disease_name} ({result.severity_label})")
        return result

    def batch_diagnose(
        self,
        image_folder: Union[str, Path],
        output_dir: Optional[Union[str, Path]] = None,
        plant_species: Optional[str] = None,
        save_visualizations: bool = True,
        file_extensions: Optional[List[str]] = None
    ) -> List[DiagnosticResult]:
        """
        Process multiple images from a folder.

        Args:
            image_folder: Path to folder containing images
            output_dir: Where to save results (optional)
            plant_species: Plant species for all images
            save_visualizations: Whether to save visualization images
            file_extensions: Image file extensions to process
                (defaults to [".jpg", ".jpeg", ".png"])

        Returns:
            List of DiagnosticResult for each image
        """
        # NOTE: default handled here rather than as a mutable default argument
        if file_extensions is None:
            file_extensions = [".jpg", ".jpeg", ".png"]

        image_folder = Path(image_folder)

        if output_dir:
            output_dir = Path(output_dir)
            output_dir.mkdir(parents=True, exist_ok=True)

        # Find all images; glob both lower- and upper-case extensions, then
        # dedupe — on case-insensitive filesystems both globs match the same
        # files and would otherwise process each image twice
        images = []
        for ext in file_extensions:
            images.extend(image_folder.glob(f"*{ext}"))
            images.extend(image_folder.glob(f"*{ext.upper()}"))
        images = sorted(set(images))

        logger.info(f"Found {len(images)} images to process")

        results = []
        for i, img_path in enumerate(images):
            logger.info(f"Processing {i+1}/{len(images)}: {img_path.name}")

            try:
                result = self.diagnose(
                    img_path,
                    plant_species=plant_species,
                    return_masks=save_visualizations
                )
                results.append(result)

                # Save visualization if requested
                if save_visualizations and output_dir:
                    self._save_visualization(
                        img_path,
                        result,
                        output_dir / f"{img_path.stem}_diagnosis.png"
                    )
            except Exception as e:
                # Best-effort batch: log and continue with remaining images
                logger.error(f"Error processing {img_path}: {e}")
                continue

        logger.info(f"Batch processing complete: {len(results)}/{len(images)} successful")
        return results

    def _save_visualization(
        self,
        image_path: Path,
        result: DiagnosticResult,
        output_path: Path
    ):
        """Save diagnostic visualization."""
        # Deferred imports: matplotlib/visualization are only needed when
        # visualizations are actually saved
        from .visualization import create_diagnostic_visualization
        import matplotlib.pyplot as plt

        image = Image.open(image_path)
        fig = create_diagnostic_visualization(
            image,
            result.segmentation_masks,
            result.severity_label,
            result.disease_name,
            result.affected_area_percent
        )
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
        plt.close(fig)

    def export_report(
        self,
        results: List[DiagnosticResult],
        output_path: Union[str, Path],
        format: str = "csv"
    ):
        """
        Export results to file.

        Args:
            results: List of diagnostic results
            output_path: Output file path
            format: Output format ("csv", "json")

        Raises:
            ValueError: If format is not "csv" or "json"
        """
        output_path = Path(output_path)

        if format == "csv":
            self._export_csv(results, output_path)
        elif format == "json":
            self._export_json(results, output_path)
        else:
            raise ValueError(f"Unknown format: {format}")

        logger.info(f"Report exported to {output_path}")

    def _export_csv(self, results: List[DiagnosticResult], output_path: Path):
        """Export to CSV."""
        with open(output_path, 'w', newline='') as f:
            if not results:
                return

            # Get fields (excluding numpy arrays)
            fields = [
                'image_path', 'timestamp', 'num_regions_detected',
                'affected_area_percent', 'detected_symptoms',
                'severity_level', 'severity_label', 'severity_confidence',
                'disease_name', 'disease_type', 'disease_confidence',
                'organic_treatments', 'urgency'
            ]

            writer = csv.DictWriter(f, fieldnames=fields)
            writer.writeheader()

            for result in results:
                row = {
                    'image_path': result.image_path,
                    'timestamp': result.timestamp,
                    'num_regions_detected': result.num_regions_detected,
                    'affected_area_percent': f"{result.affected_area_percent:.2f}",
                    'detected_symptoms': '; '.join(result.detected_symptoms),
                    'severity_level': result.severity_level,
                    'severity_label': result.severity_label,
                    'severity_confidence': f"{result.severity_confidence:.3f}",
                    'disease_name': result.disease_name,
                    'disease_type': result.disease_type,
                    'disease_confidence': f"{result.disease_confidence:.3f}",
                    'organic_treatments': '; '.join(result.organic_treatments[:3]),
                    'urgency': result.urgency
                }
                writer.writerow(row)

    def _export_json(self, results: List[DiagnosticResult], output_path: Path):
        """Export to JSON."""
        data = []
        for result in results:
            d = asdict(result)
            # Remove numpy arrays (not JSON-serializable and not useful in reports)
            d.pop('segmentation_masks', None)
            d.pop('segmentation_scores', None)
            data.append(d)

        with open(output_path, 'w') as f:
            json.dump(data, f, indent=2, default=str)

    def generate_summary_report(
        self,
        results: List[DiagnosticResult]
    ) -> str:
        """
        Generate a summary report for batch results.

        Args:
            results: List of diagnostic results

        Returns:
            Formatted summary report string
        """
        if not results:
            return "No results to summarize."

        # Calculate statistics
        total = len(results)
        healthy = sum(1 for r in results if r.severity_level == 0)
        mild = sum(1 for r in results if r.severity_level == 1)
        moderate = sum(1 for r in results if r.severity_level == 2)
        severe = sum(1 for r in results if r.severity_level == 3)

        # Disease frequency
        disease_counts = {}
        for r in results:
            disease_counts[r.disease_name] = disease_counts.get(r.disease_name, 0) + 1

        # Average affected area
        avg_affected = np.mean([r.affected_area_percent for r in results])

        report = f"""
╔══════════════════════════════════════════════════════════════╗
║                  BATCH DIAGNOSIS SUMMARY                     ║
╚══════════════════════════════════════════════════════════════╝

📊 OVERALL STATISTICS
   Total Images Analyzed: {total}

   Severity Distribution:
   ├── 🟢 Healthy:  {healthy} ({healthy/total*100:.1f}%)
   ├── 🟡 Mild:     {mild} ({mild/total*100:.1f}%)
   ├── 🟠 Moderate: {moderate} ({moderate/total*100:.1f}%)
   └── 🔴 Severe:   {severe} ({severe/total*100:.1f}%)

   Average Affected Area: {avg_affected:.1f}%

🦠 DISEASE FREQUENCY
"""
        for disease, count in sorted(disease_counts.items(), key=lambda x: -x[1]):
            report += f"   • {disease}: {count} ({count/total*100:.1f}%)\n"

        # Urgent cases
        urgent = [r for r in results if r.urgency in ['high', 'critical']]
        if urgent:
            report += f"""
⚠️  URGENT ATTENTION REQUIRED
   {len(urgent)} images require immediate attention:
"""
            for r in urgent[:5]:  # Show top 5
                report += f"   • {Path(r.image_path).name}: {r.disease_name} ({r.urgency})\n"

        report += """
═══════════════════════════════════════════════════════════════
"""
        return report


def quick_diagnose(
    image_path: str,
    use_mock: bool = True
) -> DiagnosticResult:
    """
    Quick diagnosis function for simple use cases.

    Args:
        image_path: Path to image
        use_mock: Use mock models (for testing without SAM 3)

    Returns:
        DiagnosticResult
    """
    pipeline = CropDoctorPipeline(
        use_mock_sam3=use_mock,
        use_llm=False
    )
    return pipeline.diagnose(image_path)


if __name__ == "__main__":
    # Test the pipeline with mock
    print("Testing CropDoctor Pipeline...")

    # Create test image
    test_img = Image.new("RGB", (640, 480), color=(139, 69, 19))
    test_img.save("/tmp/test_leaf.jpg")

    # Run pipeline
    pipeline = CropDoctorPipeline(use_mock_sam3=True, use_llm=False)
    result = pipeline.diagnose("/tmp/test_leaf.jpg")

    print(f"\n📋 Diagnosis Results:")
    print(f"   Disease: {result.disease_name}")
    print(f"   Type: {result.disease_type}")
    print(f"   Severity: {result.severity_label} (Level {result.severity_level})")
    print(f"   Affected Area: {result.affected_area_percent:.1f}%")
    print(f"   Urgency: {result.urgency}")
    print(f"\n🌿 Recommended Treatments:")
    for t in result.organic_treatments[:3]:
        print(f"   • {t}")