# Source: cropscan-space / src / pipeline.py
# Commit: f8eb07d ("Add disease detection app with RF-DETR and SAM2") by davidsv
"""
CropDoctor Pipeline Module
==========================
This module integrates all components into a unified diagnostic pipeline:
1. SAM 3 Segmentation - Detect disease regions
2. Severity Classification - Assess severity level
3. Treatment Recommendation - Generate actionable advice
Provides both single-image and batch processing capabilities.
"""
import torch
import numpy as np
from PIL import Image
from pathlib import Path
from typing import Dict, List, Optional, Union, Tuple
from dataclasses import dataclass, asdict
import json
import csv
from datetime import datetime
import logging
from .sam3_segmentation import SAM3Segmenter, create_segmenter, SegmentationResult
from .severity_classifier import SeverityClassifier, SeverityPrediction, SEVERITY_LABELS
from .treatment_recommender import TreatmentRecommender, TreatmentRecommendation
from .leaf_segmenter import SAM2LeafSegmenter
logger = logging.getLogger(__name__)
@dataclass
class DiagnosticResult:
    """Complete diagnostic result for an image."""
    # --- Image info ---
    image_path: str  # source path, or "in_memory" when a PIL/ndarray was passed
    timestamp: str   # ISO-8601 timestamp taken at the start of diagnosis
    # --- Segmentation results ---
    num_regions_detected: int      # number of masks returned by the segmenter
    affected_area_percent: float   # total diseased area as % of the image/leaf
    detected_symptoms: List[str]   # prompts that produced detections (healthy excluded)
    # --- Severity assessment ---
    severity_level: int        # 0=healthy, 1=mild, 2=moderate, 3=severe
    severity_label: str        # human-readable label matching severity_level
    severity_confidence: float # classifier confidence in [0, 1]
    # --- Treatment recommendations ---
    disease_name: str
    disease_type: str
    disease_confidence: float
    organic_treatments: List[str]
    chemical_treatments: List[str]
    preventive_measures: List[str]
    treatment_timing: str
    urgency: str  # e.g. "low" / "high" / "critical" — used by summary reports
    # --- Raw data for further analysis (omitted from CSV/JSON exports) ---
    segmentation_masks: Optional[np.ndarray] = None
    segmentation_scores: Optional[np.ndarray] = None
class CropDoctorPipeline:
    """
    End-to-end pipeline for plant disease diagnosis.

    Integrates SAM 3 segmentation, severity classification, and
    LLM-based treatment recommendations into a single workflow.
    Heavy components (segmenter, classifier, recommender, leaf
    segmenter) are lazily instantiated on first access, so building
    the pipeline itself is cheap.

    Example:
        >>> pipeline = CropDoctorPipeline()
        >>> result = pipeline.diagnose("path/to/leaf.jpg")
        >>> print(f"Disease: {result.disease_name}")
        >>> print(f"Severity: {result.severity_label}")
        >>> print(f"Treatment: {result.organic_treatments[0]}")
    """

    # Prompt that labels healthy tissue; detections matching it are
    # excluded from disease statistics everywhere in the pipeline.
    HEALTHY_PROMPT = "healthy green leaf"
    # Tuple (immutable) so the default cannot be mutated across calls.
    DEFAULT_EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(
        self,
        sam3_checkpoint: str = "models/sam3/sam3.pt",
        sam3_config: str = "configs/sam3_config.yaml",
        classifier_checkpoint: Optional[str] = None,
        use_llm: bool = True,
        anthropic_api_key: Optional[str] = None,
        device: Optional[str] = None,
        use_mock_sam3: bool = False,
        use_rfdetr: bool = False,
        rfdetr_checkpoint: str = "models/rfdetr/checkpoint_best_total.pth",
        rfdetr_model_size: str = "medium",
        use_leaf_segmentation: bool = False,
        sam2_checkpoint: str = "models/sam2/sam2.1_hiera_small.pt"
    ):
        """
        Initialize the CropDoctor pipeline.

        Args:
            sam3_checkpoint: Path to SAM 3 model checkpoint
            sam3_config: Path to SAM 3 configuration
            classifier_checkpoint: Path to severity classifier checkpoint
            use_llm: Whether to use Claude API for recommendations
            anthropic_api_key: Optional API key for Claude
            device: Device to use (auto-detected if None)
            use_mock_sam3: Use mock SAM 3 for testing without model
            use_rfdetr: Use RF-DETR for detection (recommended)
            rfdetr_checkpoint: Path to trained RF-DETR model
            rfdetr_model_size: RF-DETR model size (nano, small, medium, base)
            use_leaf_segmentation: Use SAM2 to segment leaf before detection
            sam2_checkpoint: Path to SAM2 checkpoint for leaf segmentation
        """
        # Auto-detect device unless the caller pinned one.
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        logger.info(f"Initializing CropDoctor Pipeline on {self.device}")

        # Components are lazily created by the properties below.
        self._segmenter = None
        self._classifier = None
        self._recommender = None
        self._leaf_segmenter = None

        # Store configuration for lazy construction.
        self.sam3_checkpoint = sam3_checkpoint
        self.sam3_config = sam3_config
        self.classifier_checkpoint = classifier_checkpoint
        self.use_llm = use_llm
        self.anthropic_api_key = anthropic_api_key
        self.use_mock_sam3 = use_mock_sam3
        self.use_rfdetr = use_rfdetr
        self.rfdetr_checkpoint = rfdetr_checkpoint
        self.rfdetr_model_size = rfdetr_model_size
        self.use_leaf_segmentation = use_leaf_segmentation
        self.sam2_checkpoint = sam2_checkpoint

        # Default prompts for disease detection; the healthy prompt is
        # included so healthy tissue can be separated from disease.
        self.disease_prompts = [
            "diseased plant tissue",
            "leaf with brown spots",
            "powdery mildew",
            "yellowing leaves",
            "pest damage",
            "wilted tissue",
            self.HEALTHY_PROMPT
        ]

    @property
    def segmenter(self) -> SAM3Segmenter:
        """Lazy load segmenter (SAM 3, Mock, or RF-DETR)."""
        if self._segmenter is None:
            self._segmenter = create_segmenter(
                self.sam3_checkpoint,
                self.sam3_config,
                use_mock=self.use_mock_sam3,
                use_rfdetr=self.use_rfdetr,
                rfdetr_checkpoint=self.rfdetr_checkpoint,
                rfdetr_model_size=self.rfdetr_model_size
            )
        return self._segmenter

    @property
    def classifier(self) -> SeverityClassifier:
        """Lazy load severity classifier."""
        if self._classifier is None:
            self._classifier = SeverityClassifier(
                checkpoint_path=self.classifier_checkpoint,
                device=self.device
            )
        return self._classifier

    @property
    def recommender(self) -> TreatmentRecommender:
        """Lazy load treatment recommender."""
        if self._recommender is None:
            self._recommender = TreatmentRecommender(
                api_key=self.anthropic_api_key,
                use_llm=self.use_llm
            )
        return self._recommender

    @property
    def leaf_segmenter(self) -> SAM2LeafSegmenter:
        """Lazy load SAM2 leaf segmenter."""
        if self._leaf_segmenter is None:
            self._leaf_segmenter = SAM2LeafSegmenter(
                checkpoint_path=self.sam2_checkpoint,
                device=self.device
            )
        return self._leaf_segmenter

    def diagnose(
        self,
        image: Union[str, Path, Image.Image, np.ndarray],
        plant_species: Optional[str] = None,
        analysis_profile: str = "standard",
        custom_prompts: Optional[List[str]] = None,
        return_masks: bool = False
    ) -> DiagnosticResult:
        """
        Perform complete diagnosis on an image.

        Args:
            image: Input image (path, PIL Image, or numpy array)
            plant_species: Optional plant species for context
            analysis_profile: SAM 3 analysis profile
            custom_prompts: Optional custom prompts for segmentation
            return_masks: Whether to include segmentation masks in result

        Returns:
            DiagnosticResult with complete diagnosis
        """
        timestamp = datetime.now().isoformat()
        image_path = str(image) if isinstance(image, (str, Path)) else "in_memory"
        logger.info(f"Starting diagnosis for {image_path}")

        # Normalize the input into a PIL RGB image.
        if isinstance(image, (str, Path)):
            pil_image = Image.open(image).convert("RGB")
        elif isinstance(image, np.ndarray):
            pil_image = Image.fromarray(image)
        else:
            pil_image = image

        # Step 0 (optional): isolate the leaf from the background with
        # SAM2 so subsequent detection runs on leaf pixels only.
        if self.use_leaf_segmentation:
            logger.info("Step 0: Leaf segmentation with SAM2")
            pil_image, _ = self.leaf_segmenter.auto_segment_leaf(
                pil_image,
                return_mask=True
            )
            logger.info("Leaf isolated from background")

        # Step 1: Disease detection (SAM3/RF-DETR).
        logger.info("Step 1: Disease detection")
        prompts = custom_prompts or self.disease_prompts
        seg_result = self.segmenter.segment_with_concepts(pil_image, prompts)

        # Step 1.5: RF-DETR yields rectangular boxes; refine them into
        # precise segmentation masks with SAM2.
        if self.use_rfdetr and len(seg_result.boxes) > 0:
            logger.info("Step 1.5: Refining detection boxes with SAM2")
            refined_masks = self.leaf_segmenter.refine_boxes_to_masks(
                pil_image,
                seg_result.boxes
            )
            # Replace rectangular masks with refined masks.
            seg_result = SegmentationResult(
                masks=refined_masks,
                boxes=seg_result.boxes,
                scores=seg_result.scores,
                prompts=seg_result.prompts,
                prompt_indices=seg_result.prompt_indices
            )

        # Affected-area statistics, excluding the healthy prompt if present.
        healthy_idx = (
            prompts.index(self.HEALTHY_PROMPT)
            if self.HEALTHY_PROMPT in prompts else None
        )
        area_stats = self.segmenter.calculate_affected_area(
            seg_result,
            healthy_prompt_idx=healthy_idx
        )

        # Symptoms = prompts with at least one detection, minus healthy.
        detected_symptoms = []
        for prompt_idx in np.unique(seg_result.prompt_indices):
            if prompt_idx < len(prompts):
                prompt = prompts[prompt_idx]
                if prompt != self.HEALTHY_PROMPT:
                    detected_symptoms.append(prompt)

        # Step 2: Severity classification on the union of disease masks.
        logger.info("Step 2: Severity classification")
        if len(seg_result.masks) > 0:
            disease_mask = np.zeros(seg_result.masks[0].shape, dtype=bool)
            for i, mask in enumerate(seg_result.masks):
                prompt_idx = seg_result.prompt_indices[i]
                if prompt_idx < len(prompts) and prompts[prompt_idx] != self.HEALTHY_PROMPT:
                    disease_mask |= mask
        else:
            disease_mask = None
        severity_result = self.classifier.classify(pil_image, mask=disease_mask)

        # Sanity-check the classifier against the measured affected area:
        # <1% forces "healthy"; <10% caps severity at "mild".
        if area_stats["total_affected_percent"] < 1:
            severity_result = SeverityPrediction(
                severity_level=0,
                severity_label="healthy",
                confidence=0.9,
                probabilities={"healthy": 0.9, "mild": 0.05, "moderate": 0.03, "severe": 0.02},
                affected_area_percent=area_stats["total_affected_percent"]
            )
        elif area_stats["total_affected_percent"] < 10 and severity_result.severity_level > 1:
            severity_result.severity_level = 1
            severity_result.severity_label = "mild"

        # Step 3: Treatment recommendations (only when disease was found).
        logger.info("Step 3: Generating treatment recommendations")
        if detected_symptoms and severity_result.severity_level > 0:
            treatment_result = self.recommender.get_recommendation(
                symptoms=detected_symptoms,
                severity=severity_result.severity_label,
                plant_species=plant_species,
                affected_area_percent=area_stats["total_affected_percent"]
            )
        else:
            # Healthy plant - no treatment needed.
            treatment_result = TreatmentRecommendation(
                disease_name="No Disease Detected",
                disease_type="healthy",
                confidence=0.9,
                symptoms_matched=[],
                organic_treatments=["Continue regular care"],
                chemical_treatments=[],
                preventive_measures=[
                    "Maintain good air circulation",
                    "Water at soil level",
                    "Monitor regularly for early symptoms"
                ],
                timing="Regular monitoring recommended",
                urgency="low",
                additional_notes="Plant appears healthy. Continue preventive care."
            )

        # Compile the final result; raw arrays only on request.
        result = DiagnosticResult(
            image_path=image_path,
            timestamp=timestamp,
            num_regions_detected=len(seg_result.masks),
            affected_area_percent=area_stats["total_affected_percent"],
            detected_symptoms=detected_symptoms,
            severity_level=severity_result.severity_level,
            severity_label=severity_result.severity_label,
            severity_confidence=severity_result.confidence,
            disease_name=treatment_result.disease_name,
            disease_type=treatment_result.disease_type,
            disease_confidence=treatment_result.confidence,
            organic_treatments=treatment_result.organic_treatments,
            chemical_treatments=treatment_result.chemical_treatments,
            preventive_measures=treatment_result.preventive_measures,
            treatment_timing=treatment_result.timing,
            urgency=treatment_result.urgency,
            segmentation_masks=seg_result.masks if return_masks else None,
            segmentation_scores=seg_result.scores if return_masks else None
        )
        logger.info(f"Diagnosis complete: {result.disease_name} ({result.severity_label})")
        return result

    def batch_diagnose(
        self,
        image_folder: Union[str, Path],
        output_dir: Optional[Union[str, Path]] = None,
        plant_species: Optional[str] = None,
        save_visualizations: bool = True,
        file_extensions: Optional[List[str]] = None
    ) -> List[DiagnosticResult]:
        """
        Process multiple images from a folder.

        Args:
            image_folder: Path to folder containing images
            output_dir: Where to save results (optional)
            plant_species: Plant species for all images
            save_visualizations: Whether to save visualization images
            file_extensions: Image file extensions to process
                (defaults to DEFAULT_EXTENSIONS)

        Returns:
            List of DiagnosticResult for each image
        """
        # Fix: previously a mutable list default was shared across calls.
        if file_extensions is None:
            file_extensions = list(self.DEFAULT_EXTENSIONS)

        image_folder = Path(image_folder)
        if output_dir:
            output_dir = Path(output_dir)
            output_dir.mkdir(parents=True, exist_ok=True)

        # Find all images. Deduplicate because on case-insensitive
        # filesystems the lower/upper globs match the same files; sort
        # for deterministic processing order.
        images = []
        for ext in file_extensions:
            images.extend(image_folder.glob(f"*{ext}"))
            images.extend(image_folder.glob(f"*{ext.upper()}"))
        images = sorted(set(images))
        logger.info(f"Found {len(images)} images to process")

        results = []
        for i, img_path in enumerate(images):
            logger.info(f"Processing {i+1}/{len(images)}: {img_path.name}")
            try:
                result = self.diagnose(
                    img_path,
                    plant_species=plant_species,
                    return_masks=save_visualizations
                )
                results.append(result)
                # Save visualization if requested.
                if save_visualizations and output_dir:
                    self._save_visualization(
                        img_path,
                        result,
                        output_dir / f"{img_path.stem}_diagnosis.png"
                    )
            except Exception as e:
                # Best-effort batch: log the failure and continue.
                logger.error(f"Error processing {img_path}: {e}")
                continue
        logger.info(f"Batch processing complete: {len(results)}/{len(images)} successful")
        return results

    def _save_visualization(
        self,
        image_path: Path,
        result: DiagnosticResult,
        output_path: Path
    ):
        """Save diagnostic visualization figure to *output_path*."""
        # Imported lazily so matplotlib is only required for visualization.
        from .visualization import create_diagnostic_visualization
        image = Image.open(image_path)
        fig = create_diagnostic_visualization(
            image,
            result.segmentation_masks,
            result.severity_label,
            result.disease_name,
            result.affected_area_percent
        )
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
        import matplotlib.pyplot as plt
        plt.close(fig)  # release figure memory

    def export_report(
        self,
        results: List[DiagnosticResult],
        output_path: Union[str, Path],
        format: str = "csv"
    ):
        """
        Export results to file.

        Args:
            results: List of diagnostic results
            output_path: Output file path
            format: Output format ("csv", "json")

        Raises:
            ValueError: If *format* is not "csv" or "json".
        """
        output_path = Path(output_path)
        if format == "csv":
            self._export_csv(results, output_path)
        elif format == "json":
            self._export_json(results, output_path)
        else:
            raise ValueError(f"Unknown format: {format}")
        logger.info(f"Report exported to {output_path}")

    def _export_csv(self, results: List[DiagnosticResult], output_path: Path):
        """Export results to a flat CSV (numpy arrays omitted)."""
        # Fix: return before opening so no empty file is created.
        if not results:
            return
        # Scalar/list fields only; numpy arrays are excluded.
        fields = [
            'image_path', 'timestamp', 'num_regions_detected',
            'affected_area_percent', 'detected_symptoms',
            'severity_level', 'severity_label', 'severity_confidence',
            'disease_name', 'disease_type', 'disease_confidence',
            'organic_treatments', 'urgency'
        ]
        # Explicit UTF-8 so exports are portable across platforms.
        with open(output_path, 'w', newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(f, fieldnames=fields)
            writer.writeheader()
            for result in results:
                writer.writerow({
                    'image_path': result.image_path,
                    'timestamp': result.timestamp,
                    'num_regions_detected': result.num_regions_detected,
                    'affected_area_percent': f"{result.affected_area_percent:.2f}",
                    'detected_symptoms': '; '.join(result.detected_symptoms),
                    'severity_level': result.severity_level,
                    'severity_label': result.severity_label,
                    'severity_confidence': f"{result.severity_confidence:.3f}",
                    'disease_name': result.disease_name,
                    'disease_type': result.disease_type,
                    'disease_confidence': f"{result.disease_confidence:.3f}",
                    'organic_treatments': '; '.join(result.organic_treatments[:3]),
                    'urgency': result.urgency
                })

    def _export_json(self, results: List[DiagnosticResult], output_path: Path):
        """Export results to JSON (numpy arrays stripped)."""
        data = []
        for result in results:
            d = asdict(result)
            # Numpy arrays are not JSON-serializable; drop them.
            d.pop('segmentation_masks', None)
            d.pop('segmentation_scores', None)
            data.append(d)
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, default=str)

    def generate_summary_report(
        self,
        results: List[DiagnosticResult]
    ) -> str:
        """
        Generate a summary report for batch results.

        Args:
            results: List of diagnostic results

        Returns:
            Formatted summary report string
        """
        if not results:
            return "No results to summarize."

        # Severity distribution.
        total = len(results)
        healthy = sum(1 for r in results if r.severity_level == 0)
        mild = sum(1 for r in results if r.severity_level == 1)
        moderate = sum(1 for r in results if r.severity_level == 2)
        severe = sum(1 for r in results if r.severity_level == 3)

        # Disease frequency.
        disease_counts = {}
        for r in results:
            disease_counts[r.disease_name] = disease_counts.get(r.disease_name, 0) + 1

        # Average affected area.
        avg_affected = np.mean([r.affected_area_percent for r in results])

        report = f"""
╔══════════════════════════════════════════════════════════════╗
β•‘ BATCH DIAGNOSIS SUMMARY β•‘
β•šβ•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•
πŸ“Š OVERALL STATISTICS
Total Images Analyzed: {total}
Severity Distribution:
β”œβ”€β”€ 🟒 Healthy: {healthy} ({healthy/total*100:.1f}%)
β”œβ”€β”€ 🟑 Mild: {mild} ({mild/total*100:.1f}%)
β”œβ”€β”€ 🟠 Moderate: {moderate} ({moderate/total*100:.1f}%)
└── πŸ”΄ Severe: {severe} ({severe/total*100:.1f}%)
Average Affected Area: {avg_affected:.1f}%
🦠 DISEASE FREQUENCY
"""
        for disease, count in sorted(disease_counts.items(), key=lambda x: -x[1]):
            report += f" β€’ {disease}: {count} ({count/total*100:.1f}%)\n"

        # Urgent cases.
        urgent = [r for r in results if r.urgency in ['high', 'critical']]
        if urgent:
            report += f"""
⚠️ URGENT ATTENTION REQUIRED
{len(urgent)} images require immediate attention:
"""
            for r in urgent[:5]:  # Show top 5
                report += f" β€’ {Path(r.image_path).name}: {r.disease_name} ({r.urgency})\n"
        report += """
═══════════════════════════════════════════════════════════════
"""
        return report
def quick_diagnose(
    image_path: str,
    use_mock: bool = True
) -> DiagnosticResult:
    """
    Quick diagnosis function for simple use cases.

    Args:
        image_path: Path to image
        use_mock: Use mock models (for testing without SAM 3)

    Returns:
        DiagnosticResult
    """
    # One-shot convenience wrapper: build a throwaway pipeline
    # (mock segmenter by default, LLM disabled) and diagnose once.
    doctor = CropDoctorPipeline(use_llm=False, use_mock_sam3=use_mock)
    return doctor.diagnose(image_path)
if __name__ == "__main__":
    # Smoke-test the pipeline with the mock segmenter and no LLM.
    print("Testing CropDoctor Pipeline...")

    # Build a throwaway brown test image on disk.
    sample = Image.new("RGB", (640, 480), color=(139, 69, 19))
    sample.save("/tmp/test_leaf.jpg")

    # Run a single diagnosis end to end.
    doctor = CropDoctorPipeline(use_mock_sam3=True, use_llm=False)
    diagnosis = doctor.diagnose("/tmp/test_leaf.jpg")

    print(f"\nπŸ“‹ Diagnosis Results:")
    print(f" Disease: {diagnosis.disease_name}")
    print(f" Type: {diagnosis.disease_type}")
    print(f" Severity: {diagnosis.severity_label} (Level {diagnosis.severity_level})")
    print(f" Affected Area: {diagnosis.affected_area_percent:.1f}%")
    print(f" Urgency: {diagnosis.urgency}")
    print(f"\n🌿 Recommended Treatments:")
    for treatment in diagnosis.organic_treatments[:3]:
        print(f" β€’ {treatment}")