Spaces:
Sleeping
Sleeping
davidsv commited on
Commit ·
f8eb07d
1
Parent(s): d7b4d0c
Add disease detection app with RF-DETR and SAM2
Browse files- app.py +213 -0
- configs/sam3_config.yaml +138 -0
- models/rfdetr/checkpoint_best_total.pth +3 -0
- models/sam2/sam2.1_hiera_small.pt +3 -0
- requirements.txt +12 -0
- src/__init__.py +98 -0
- src/__pycache__/__init__.cpython-313.pyc +0 -0
- src/__pycache__/gradio_demo.cpython-313.pyc +0 -0
- src/__pycache__/leaf_segmenter.cpython-313.pyc +0 -0
- src/__pycache__/pipeline.cpython-313.pyc +0 -0
- src/__pycache__/sam3_segmentation.cpython-313.pyc +0 -0
- src/__pycache__/severity_classifier.cpython-313.pyc +0 -0
- src/__pycache__/treatment_recommender.cpython-313.pyc +0 -0
- src/__pycache__/visualization.cpython-313.pyc +0 -0
- src/gradio_demo.py +391 -0
- src/leaf_segmenter.py +388 -0
- src/pipeline.py +631 -0
- src/sam3_segmentation.py +864 -0
- src/severity_classifier.py +590 -0
- src/treatment_recommender.py +525 -0
- src/visualization.py +460 -0
app.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
CropDoctor - Plant Disease Detection
|
| 4 |
+
Hugging Face Space Demo
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import gradio as gr
|
| 8 |
+
import numpy as np
|
| 9 |
+
from PIL import Image
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
# Model paths
|
| 14 |
+
MODEL_DIR = Path("models")
|
| 15 |
+
RFDETR_CHECKPOINT = MODEL_DIR / "rfdetr" / "checkpoint_best_total.pth"
|
| 16 |
+
SAM2_CHECKPOINT = MODEL_DIR / "sam2" / "sam2.1_hiera_small.pt"
|
| 17 |
+
|
| 18 |
+
# Lazy loaded components
|
| 19 |
+
segmenter = None
|
| 20 |
+
leaf_segmenter = None
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def load_models():
    """Load the RF-DETR detector and SAM2 segmenter on first use (idempotent).

    Results are cached in the module-level globals ``segmenter`` and
    ``leaf_segmenter`` so repeated Gradio requests reuse the loaded models.

    Raises:
        Whatever the underlying model constructors raise (e.g. missing
        checkpoint files); in that case no global is left half-initialized.
    """
    global segmenter, leaf_segmenter

    # Bug fix: the original checked only `segmenter`. If SAM2 loading failed
    # after RF-DETR succeeded, later calls would skip loading entirely and
    # leave `leaf_segmenter` as None, crashing at first use. Require BOTH.
    if segmenter is not None and leaf_segmenter is not None:
        return

    print("Loading models...")

    # Imported lazily to keep app startup (and Space cold-boot) fast.
    from src.sam3_segmentation import RFDETRSegmenter
    from src.leaf_segmenter import SAM2LeafSegmenter

    # Construct into locals first and publish to the globals only after both
    # constructors succeed, so a partial failure cannot poison the cache.
    detector = RFDETRSegmenter(
        checkpoint_path=str(RFDETR_CHECKPOINT),
        model_size="medium"
    )
    leaf_seg = SAM2LeafSegmenter(
        checkpoint_path=str(SAM2_CHECKPOINT)
    )

    segmenter = detector
    leaf_segmenter = leaf_seg

    print("Models loaded!")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def detect_disease(
    image: np.ndarray,
    use_leaf_segmentation: bool = True,
    confidence_threshold: float = 0.3
) -> tuple:
    """
    Detect plant diseases in an image.

    Args:
        image: Input image
        use_leaf_segmentation: Whether to isolate leaf first
        confidence_threshold: Detection confidence threshold

    Returns:
        Tuple of (annotated_image, detection_info)
    """
    if image is None:
        return None, "Please upload an image"

    load_models()

    pil_image = Image.fromarray(image)
    original_image = image.copy()

    # Step 1: optionally isolate the leaf with SAM2 before detection.
    leaf_mask = None
    segmented_image = None
    detection_input = pil_image
    if use_leaf_segmentation:
        segmented_pil, leaf_mask = leaf_segmenter.auto_segment_leaf(
            pil_image, return_mask=True
        )
        segmented_image = np.array(segmented_pil)
        detection_input = segmented_pil

    # Step 2: disease detection with RF-DETR using text concepts.
    seg_result = segmenter.segment_with_concepts(
        detection_input,
        ["diseased plant tissue", "leaf spot", "disease symptom"],
        confidence_threshold=confidence_threshold
    )
    num_detections = len(seg_result.boxes)

    # Step 3: refine detected boxes into pixel masks with SAM2.
    if num_detections > 0:
        refined_masks = leaf_segmenter.refine_boxes_to_masks(
            detection_input,
            seg_result.boxes
        )
    else:
        refined_masks = np.zeros((0, image.shape[0], image.shape[1]), dtype=bool)

    # Build the visualization on whichever image the user should see.
    from src.visualization import create_mask_overlay

    base_image = (
        segmented_image
        if use_leaf_segmentation and segmented_image is not None
        else original_image
    )
    annotated = (
        create_mask_overlay(base_image, refined_masks, alpha=0.5)
        if num_detections > 0
        else base_image
    )

    # Compose the markdown summary shown next to the image.
    if num_detections == 0:
        info = "### No disease detected\n\nThe leaf appears healthy or no disease symptoms were found."
    else:
        # Union of all per-region masks gives the total affected area.
        total_mask = np.zeros(refined_masks[0].shape, dtype=bool)
        for mask in refined_masks:
            total_mask |= mask

        if use_leaf_segmentation and leaf_mask is not None:
            # Fraction of the leaf itself (denominator guarded against 0).
            affected_percent = (total_mask & leaf_mask).sum() / max(leaf_mask.sum(), 1) * 100
        else:
            # Fraction of the whole frame.
            affected_percent = total_mask.sum() / (image.shape[0] * image.shape[1]) * 100

        info = f"""### Disease Detected

- **Regions found**: {num_detections}
- **Affected area**: {affected_percent:.1f}%
- **Confidence threshold**: {confidence_threshold:.0%}

The colored overlays show detected disease regions.
"""

    return annotated, info
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def create_demo():
    """Build and return the Gradio Blocks interface for the demo."""

    with gr.Blocks(title="CropDoctor - Plant Disease Detection") as demo:

        gr.Markdown("""
        # CropDoctor - Plant Disease Detection

        Upload a plant leaf image to detect disease regions using AI.

        **Models used:**
        - **RF-DETR**: Fine-tuned object detector for disease localization
        - **SAM2**: Segment Anything Model 2 for precise mask generation
        """)

        with gr.Row():
            # Left column: inputs and controls.
            with gr.Column():
                image_input = gr.Image(
                    label="Upload Plant Image",
                    type="numpy",
                    height=400
                )

                with gr.Row():
                    isolate_leaf = gr.Checkbox(
                        value=True,
                        label="Isolate leaf (SAM2)",
                        info="Segment leaf before detection to reduce false positives"
                    )
                    threshold_slider = gr.Slider(
                        minimum=0.1,
                        maximum=0.9,
                        value=0.3,
                        step=0.05,
                        label="Confidence Threshold"
                    )

                run_button = gr.Button("Detect Disease", variant="primary", size="lg")

            # Right column: annotated result and textual summary.
            with gr.Column():
                result_image = gr.Image(
                    label="Detection Result",
                    type="numpy",
                    height=400
                )
                info_panel = gr.Markdown(
                    label="Detection Info",
                    value="Upload an image and click 'Detect Disease' to see results."
                )

        # Wire the button to the detection function.
        run_button.click(
            fn=detect_disease,
            inputs=[image_input, isolate_leaf, threshold_slider],
            outputs=[result_image, info_panel]
        )

        gr.Markdown("""
        ---
        **Note**: This demo uses models trained on the PlantVillage dataset.
        For best results, use clear images of individual plant leaves.
        """)

    return demo
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
# Hugging Face Spaces / local entry point: build the UI and start the server.
if __name__ == "__main__":
    demo = create_demo()
    demo.launch()
|
configs/sam3_config.yaml
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SAM 3 Configuration for CropDoctor-Semantic
|
| 2 |
+
# ============================================
|
| 3 |
+
|
| 4 |
+
model:
|
| 5 |
+
name: "sam3"
|
| 6 |
+
checkpoint: "models/sam3/sam3.pt"
|
| 7 |
+
device: "cuda" # cuda, cpu, mps
|
| 8 |
+
half_precision: true # Use FP16 for faster inference
|
| 9 |
+
|
| 10 |
+
inference:
|
| 11 |
+
# Confidence thresholds
|
| 12 |
+
confidence_threshold: 0.25
|
| 13 |
+
presence_threshold: 0.5
|
| 14 |
+
|
| 15 |
+
# Object limits
|
| 16 |
+
max_objects_per_prompt: 50
|
| 17 |
+
min_mask_area: 100 # pixels
|
| 18 |
+
|
| 19 |
+
# Post-processing
|
| 20 |
+
apply_nms: true
|
| 21 |
+
nms_threshold: 0.5
|
| 22 |
+
|
| 23 |
+
# Disease detection prompts (ordered by specificity)
|
| 24 |
+
prompts:
|
| 25 |
+
# General disease detection
|
| 26 |
+
general:
|
| 27 |
+
- "diseased plant tissue"
|
| 28 |
+
- "infected leaf area"
|
| 29 |
+
- "plant disease symptoms"
|
| 30 |
+
- "abnormal leaf coloration"
|
| 31 |
+
|
| 32 |
+
# Healthy reference
|
| 33 |
+
healthy:
|
| 34 |
+
- "healthy green leaf"
|
| 35 |
+
- "normal plant tissue"
|
| 36 |
+
- "unaffected leaf area"
|
| 37 |
+
|
| 38 |
+
# Fungal diseases
|
| 39 |
+
fungal:
|
| 40 |
+
- "powdery mildew coating"
|
| 41 |
+
- "rust pustules on leaf"
|
| 42 |
+
- "leaf spot lesions"
|
| 43 |
+
- "anthracnose dark lesions"
|
| 44 |
+
- "downy mildew"
|
| 45 |
+
- "gray mold"
|
| 46 |
+
- "scab lesions"
|
| 47 |
+
|
| 48 |
+
# Bacterial diseases
|
| 49 |
+
bacterial:
|
| 50 |
+
- "bacterial blight"
|
| 51 |
+
- "water-soaked lesions"
|
| 52 |
+
- "angular leaf spots"
|
| 53 |
+
- "bacterial canker"
|
| 54 |
+
- "soft rot"
|
| 55 |
+
|
| 56 |
+
# Viral diseases
|
| 57 |
+
viral:
|
| 58 |
+
- "mosaic pattern on leaf"
|
| 59 |
+
- "leaf curl symptoms"
|
| 60 |
+
- "yellowing veins"
|
| 61 |
+
- "ring spots"
|
| 62 |
+
- "stunted growth"
|
| 63 |
+
|
| 64 |
+
# Nutrient deficiency
|
| 65 |
+
nutrient:
|
| 66 |
+
- "chlorosis yellowing"
|
| 67 |
+
- "interveinal chlorosis"
|
| 68 |
+
- "purple leaf coloration"
|
| 69 |
+
- "brown leaf edges"
|
| 70 |
+
- "pale green leaves"
|
| 71 |
+
|
| 72 |
+
# Pest damage
|
| 73 |
+
pest:
|
| 74 |
+
- "insect chewing damage"
|
| 75 |
+
- "leaf mining trails"
|
| 76 |
+
- "aphid colony"
|
| 77 |
+
- "caterpillar damage"
|
| 78 |
+
- "mite damage stippling"
|
| 79 |
+
- "holes in leaves"
|
| 80 |
+
|
| 81 |
+
# Prompt combinations for comprehensive analysis
|
| 82 |
+
analysis_profiles:
|
| 83 |
+
quick_scan:
|
| 84 |
+
prompts:
|
| 85 |
+
- "diseased plant tissue"
|
| 86 |
+
- "healthy green leaf"
|
| 87 |
+
description: "Fast binary healthy/diseased detection"
|
| 88 |
+
|
| 89 |
+
standard:
|
| 90 |
+
prompts:
|
| 91 |
+
- "diseased plant tissue"
|
| 92 |
+
- "leaf spot lesions"
|
| 93 |
+
- "yellowing leaves"
|
| 94 |
+
- "pest damage"
|
| 95 |
+
description: "Standard multi-symptom analysis"
|
| 96 |
+
|
| 97 |
+
comprehensive:
|
| 98 |
+
prompts:
|
| 99 |
+
- "diseased plant tissue"
|
| 100 |
+
- "powdery mildew"
|
| 101 |
+
- "rust pustules"
|
| 102 |
+
- "bacterial blight"
|
| 103 |
+
- "mosaic pattern"
|
| 104 |
+
- "chlorosis"
|
| 105 |
+
- "insect damage"
|
| 106 |
+
- "healthy tissue"
|
| 107 |
+
description: "Full diagnostic scan"
|
| 108 |
+
|
| 109 |
+
pest_focused:
|
| 110 |
+
prompts:
|
| 111 |
+
- "insect chewing damage"
|
| 112 |
+
- "leaf mining trails"
|
| 113 |
+
- "aphid infestation"
|
| 114 |
+
- "caterpillar damage"
|
| 115 |
+
- "mite stippling"
|
| 116 |
+
description: "Pest-specific detection"
|
| 117 |
+
|
| 118 |
+
# Visualization settings
|
| 119 |
+
visualization:
|
| 120 |
+
mask_alpha: 0.5
|
| 121 |
+
colormap: "viridis"
|
| 122 |
+
show_confidence: true
|
| 123 |
+
save_format: "png"
|
| 124 |
+
dpi: 150
|
| 125 |
+
|
| 126 |
+
# Color scheme for severity
|
| 127 |
+
severity_colors:
|
| 128 |
+
healthy: [0, 255, 0] # Green
|
| 129 |
+
mild: [255, 255, 0] # Yellow
|
| 130 |
+
moderate: [255, 165, 0] # Orange
|
| 131 |
+
severe: [255, 0, 0] # Red
|
| 132 |
+
|
| 133 |
+
# Performance optimization
|
| 134 |
+
optimization:
|
| 135 |
+
batch_size: 1
|
| 136 |
+
num_workers: 4
|
| 137 |
+
prefetch_factor: 2
|
| 138 |
+
pin_memory: true
|
models/rfdetr/checkpoint_best_total.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e24e651f9f144db066cd372fc12f24f8d3a3400d9fddb514638ca513f65d3152
|
| 3 |
+
size 133680431
|
models/sam2/sam2.1_hiera_small.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6d1aa6f30de5c92224f8172114de081d104bbd23dd9dc5c58996f0cad5dc4d38
|
| 3 |
+
size 184416285
|
requirements.txt
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Requirements for Hugging Face Space
|
| 2 |
+
gradio>=4.0.0
|
| 3 |
+
torch>=2.0.0
|
| 4 |
+
torchvision>=0.15.0
|
| 5 |
+
numpy>=1.24.0
|
| 6 |
+
Pillow>=9.0.0
|
| 7 |
+
opencv-python-headless>=4.8.0
|
| 8 |
+
supervision>=0.16.0
|
| 9 |
+
rfdetr
|
| 10 |
+
sam2
|
| 11 |
+
scipy
|
| 12 |
+
PyYAML
|
src/__init__.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
CropDoctor-Semantic: AI-Powered Plant Disease Diagnosis
|
| 3 |
+
========================================================
|
| 4 |
+
|
| 5 |
+
A comprehensive pipeline for plant disease diagnosis using:
|
| 6 |
+
- SAM 3 (Segment Anything Model 3) for concept-based segmentation
|
| 7 |
+
- CNN-based severity classification
|
| 8 |
+
- LLM-powered treatment recommendations
|
| 9 |
+
|
| 10 |
+
Main Components:
|
| 11 |
+
- SAM3Segmenter: Zero-shot disease region segmentation
|
| 12 |
+
- SeverityClassifier: CNN for severity assessment
|
| 13 |
+
- TreatmentRecommender: Claude API integration for treatment advice
|
| 14 |
+
- CropDoctorPipeline: End-to-end diagnostic pipeline
|
| 15 |
+
|
| 16 |
+
Quick Start:
|
| 17 |
+
>>> from src.pipeline import CropDoctorPipeline
|
| 18 |
+
>>> pipeline = CropDoctorPipeline()
|
| 19 |
+
>>> result = pipeline.diagnose("path/to/leaf.jpg")
|
| 20 |
+
>>> print(result.disease_name, result.severity_label)
|
| 21 |
+
|
| 22 |
+
For more information, see the README.md or documentation.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
__version__ = "0.1.0"
|
| 26 |
+
__author__ = "CropDoctor Team"
|
| 27 |
+
|
| 28 |
+
from .sam3_segmentation import (
|
| 29 |
+
SAM3Segmenter,
|
| 30 |
+
MockSAM3Segmenter,
|
| 31 |
+
create_segmenter,
|
| 32 |
+
SegmentationResult
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
from .severity_classifier import (
|
| 36 |
+
SeverityClassifier,
|
| 37 |
+
SeverityClassifierCNN,
|
| 38 |
+
SeverityPrediction,
|
| 39 |
+
SEVERITY_LABELS,
|
| 40 |
+
SEVERITY_DESCRIPTIONS
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
from .treatment_recommender import (
|
| 44 |
+
TreatmentRecommender,
|
| 45 |
+
TreatmentRecommendation,
|
| 46 |
+
DISEASE_DATABASE
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
from .pipeline import (
|
| 50 |
+
CropDoctorPipeline,
|
| 51 |
+
DiagnosticResult,
|
| 52 |
+
quick_diagnose
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
from .visualization import (
|
| 56 |
+
create_diagnostic_visualization,
|
| 57 |
+
create_mask_overlay,
|
| 58 |
+
create_severity_heatmap,
|
| 59 |
+
create_comparison_view,
|
| 60 |
+
create_treatment_card,
|
| 61 |
+
save_visualization
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
__all__ = [
|
| 65 |
+
# Version
|
| 66 |
+
'__version__',
|
| 67 |
+
|
| 68 |
+
# Segmentation
|
| 69 |
+
'SAM3Segmenter',
|
| 70 |
+
'MockSAM3Segmenter',
|
| 71 |
+
'create_segmenter',
|
| 72 |
+
'SegmentationResult',
|
| 73 |
+
|
| 74 |
+
# Classification
|
| 75 |
+
'SeverityClassifier',
|
| 76 |
+
'SeverityClassifierCNN',
|
| 77 |
+
'SeverityPrediction',
|
| 78 |
+
'SEVERITY_LABELS',
|
| 79 |
+
'SEVERITY_DESCRIPTIONS',
|
| 80 |
+
|
| 81 |
+
# Recommendations
|
| 82 |
+
'TreatmentRecommender',
|
| 83 |
+
'TreatmentRecommendation',
|
| 84 |
+
'DISEASE_DATABASE',
|
| 85 |
+
|
| 86 |
+
# Pipeline
|
| 87 |
+
'CropDoctorPipeline',
|
| 88 |
+
'DiagnosticResult',
|
| 89 |
+
'quick_diagnose',
|
| 90 |
+
|
| 91 |
+
# Visualization
|
| 92 |
+
'create_diagnostic_visualization',
|
| 93 |
+
'create_mask_overlay',
|
| 94 |
+
'create_severity_heatmap',
|
| 95 |
+
'create_comparison_view',
|
| 96 |
+
'create_treatment_card',
|
| 97 |
+
'save_visualization',
|
| 98 |
+
]
|
src/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (2.06 kB). View file
|
|
|
src/__pycache__/gradio_demo.cpython-313.pyc
ADDED
|
Binary file (13 kB). View file
|
|
|
src/__pycache__/leaf_segmenter.cpython-313.pyc
ADDED
|
Binary file (13.6 kB). View file
|
|
|
src/__pycache__/pipeline.cpython-313.pyc
ADDED
|
Binary file (25.8 kB). View file
|
|
|
src/__pycache__/sam3_segmentation.cpython-313.pyc
ADDED
|
Binary file (35.1 kB). View file
|
|
|
src/__pycache__/severity_classifier.cpython-313.pyc
ADDED
|
Binary file (24.2 kB). View file
|
|
|
src/__pycache__/treatment_recommender.cpython-313.pyc
ADDED
|
Binary file (17 kB). View file
|
|
|
src/__pycache__/visualization.cpython-313.pyc
ADDED
|
Binary file (18.3 kB). View file
|
|
|
src/gradio_demo.py
ADDED
|
@@ -0,0 +1,391 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Gradio Demo for CropDoctor-Semantic
|
| 4 |
+
====================================
|
| 5 |
+
|
| 6 |
+
Interactive web interface for plant disease diagnosis.
|
| 7 |
+
|
| 8 |
+
Usage:
|
| 9 |
+
python src/gradio_demo.py
|
| 10 |
+
|
| 11 |
+
Then open http://localhost:7860 in your browser.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import gradio as gr
|
| 15 |
+
import numpy as np
|
| 16 |
+
from PIL import Image
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from datetime import datetime
|
| 19 |
+
import sys
|
| 20 |
+
|
| 21 |
+
# Add src to path
|
| 22 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 23 |
+
|
| 24 |
+
from src.pipeline import CropDoctorPipeline, DiagnosticResult
|
| 25 |
+
from src.visualization import create_mask_overlay, SEVERITY_COLORS
|
| 26 |
+
from src.treatment_recommender import TreatmentRecommender
|
| 27 |
+
|
| 28 |
+
# Initialize pipeline with mock for demo
|
| 29 |
+
# Set use_mock_sam3=False when you have SAM 3 installed
|
| 30 |
+
pipeline = None
|
| 31 |
+
|
| 32 |
+
# Output directory for saving results
|
| 33 |
+
OUTPUT_DIR = Path("output")
|
| 34 |
+
OUTPUT_DIR.mkdir(exist_ok=True)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def get_pipeline(use_leaf_segmentation: bool = False):
    """Return the shared CropDoctorPipeline, constructing it on first call.

    The pipeline is cached in the module-level ``pipeline`` global; on every
    call the leaf-segmentation flag is synced to the current request.
    """
    global pipeline

    # Fast path: pipeline already built — just refresh the per-request flag.
    if pipeline is not None:
        pipeline.use_leaf_segmentation = use_leaf_segmentation
        return pipeline

    pipeline = CropDoctorPipeline(
        classifier_checkpoint="models/severity_classifier/best.pt",
        use_mock_sam3=False,
        use_rfdetr=True,
        rfdetr_checkpoint="models/rfdetr/checkpoint_best_total.pth",
        use_llm=False,
        use_leaf_segmentation=use_leaf_segmentation,
        sam2_checkpoint="models/sam2/sam2.1_hiera_small.pt"
    )
    pipeline.use_leaf_segmentation = use_leaf_segmentation
    return pipeline
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def diagnose_image(
    image: np.ndarray,
    plant_species: str,
    analysis_profile: str,
    use_leaf_segmentation: bool = True
) -> tuple:
    """
    Main diagnosis function for Gradio.

    Args:
        image: Input image from Gradio
        plant_species: Selected plant species ("Auto-detect" disables the hint)
        analysis_profile: Analysis profile name as shown in the UI
        use_leaf_segmentation: Whether to segment leaf before detection

    Returns:
        Tuple of (annotated_image, diagnosis_text, treatment_text, severity_label)
    """
    if image is None:
        return None, "Please upload an image", "", "unknown"

    # Every request gets its own timestamped output folder.
    run_dir = OUTPUT_DIR / datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir.mkdir(exist_ok=True)

    # Persist the untouched upload first.
    source_pil = Image.fromarray(image)
    source_pil.save(run_dir / "original.png")

    pipe = get_pipeline(use_leaf_segmentation=use_leaf_segmentation)

    # Optional leaf isolation, saved before any annotation is drawn.
    segmented_image = None
    if use_leaf_segmentation:
        # NOTE(review): assumes the pipeline exposes `leaf_segmenter` whenever
        # leaf segmentation was requested — confirm against pipeline.py.
        seg_pil, _leaf_mask = pipe.leaf_segmenter.auto_segment_leaf(
            source_pil, return_mask=True
        )
        segmented_image = np.array(seg_pil)
        seg_pil.save(run_dir / "segmented_leaf.png")

    result = pipe.diagnose(
        image,
        plant_species=None if plant_species == "Auto-detect" else plant_species,
        analysis_profile=analysis_profile.lower().replace(" ", "_"),
        return_masks=True
    )

    masks = result.segmentation_masks
    has_masks = masks is not None and len(masks) > 0

    if has_masks:
        # Overlay detections on the original (and on the segmented leaf too).
        annotated_original = create_mask_overlay(image, masks, alpha=0.5)
        Image.fromarray(annotated_original).save(run_dir / "annotated_original.png")

        if segmented_image is not None:
            annotated_segmented = create_mask_overlay(segmented_image, masks, alpha=0.5)
            Image.fromarray(annotated_segmented).save(run_dir / "annotated_segmented.png")
        else:
            annotated_segmented = annotated_original
    else:
        # Nothing to overlay: save/display the raw images.
        annotated_original = image
        annotated_segmented = segmented_image if segmented_image is not None else image
        Image.fromarray(annotated_original).save(run_dir / "annotated_original.png")
        if segmented_image is not None:
            Image.fromarray(segmented_image).save(run_dir / "annotated_segmented.png")

    # Text panels, plus a pointer to where this run's artifacts were saved.
    diagnosis_text = format_diagnosis(result)
    diagnosis_text += f"\n\n---\n*Output saved to: `{run_dir}`*"
    treatment_text = format_treatment(result)

    # Show the segmented+annotated view when leaf isolation was requested.
    display_image = annotated_segmented if use_leaf_segmentation else annotated_original
    return display_image, diagnosis_text, treatment_text, result.severity_label.lower()
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def format_diagnosis(result: DiagnosticResult) -> str:
    """Render a DiagnosticResult as a markdown diagnosis summary."""

    # Severity label → status emoji; unknown labels fall back to neutral.
    emoji = {
        "healthy": "π’",
        "mild": "π‘",
        "moderate": "π ",
        "severe": "π΄"
    }.get(result.severity_label.lower(), "βͺ")

    text = f"""
## π Diagnostic Results

### Disease Identification
- **Disease**: {result.disease_name}
- **Type**: {result.disease_type.capitalize()}
- **Confidence**: {result.disease_confidence:.0%}

### Severity Assessment
- **Level**: {emoji} **{result.severity_label.upper()}** (Level {result.severity_level}/3)
- **Affected Area**: {result.affected_area_percent:.1f}%
- **Confidence**: {result.severity_confidence:.0%}

### Detected Symptoms
"""

    # Symptom bullet list (or an explicit "none" marker).
    if result.detected_symptoms:
        text += "".join(f"- {symptom}\n" for symptom in result.detected_symptoms)
    else:
        text += "- No disease symptoms detected\n"

    action = (
        "Immediate action required!"
        if result.urgency == "critical"
        else "Monitor and treat as recommended."
    )
    text += f"""
### Urgency
**{result.urgency.upper()}** - {action}
"""

    return text
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def format_treatment(result: DiagnosticResult) -> str:
    """Render treatment recommendations from a DiagnosticResult as markdown."""

    parts = ["""
## πΏ Treatment Recommendations

### Organic Treatments
"""]

    # Top 5 organic options, numbered.
    parts.extend(
        f"{i}. {treatment}\n"
        for i, treatment in enumerate(result.organic_treatments[:5], 1)
    )

    # Chemical section only when there is something to show (top 3).
    if result.chemical_treatments:
        parts.append("\n### Chemical Treatments\n")
        parts.extend(
            f"{i}. {treatment}\n"
            for i, treatment in enumerate(result.chemical_treatments[:3], 1)
        )

    parts.append("\n### Preventive Measures\n")
    parts.extend(
        f"{i}. {measure}\n"
        for i, measure in enumerate(result.preventive_measures[:5], 1)
    )

    parts.append(f"""
### Application Timing
{result.treatment_timing}

---
*Note: These recommendations are AI-generated. Always consult with a local agricultural expert for critical decisions.*
""")

    return "".join(parts)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def create_demo():
    """Create the Gradio interface for CropDoctor-Semantic.

    Lays out an input column (image upload, species/profile dropdowns,
    leaf-segmentation toggle, analyze button) next to an output column
    (annotated image, hidden severity textbox), followed by side-by-side
    diagnosis and treatment markdown panels. Wires the analyze button to
    `diagnose_image` (defined elsewhere in this module).

    Returns:
        The assembled `gr.Blocks` demo, ready to `launch()`.
    """

    # Custom CSS: one background color per severity level plus colored
    # left borders distinguishing the diagnosis and treatment panels.
    css = """
    .severity-healthy { background-color: #2ECC71 !important; }
    .severity-mild { background-color: #F1C40F !important; }
    .severity-moderate { background-color: #E67E22 !important; }
    .severity-severe { background-color: #E74C3C !important; }

    .diagnosis-panel {
        border-left: 4px solid #3498DB;
        padding-left: 1rem;
    }

    .treatment-panel {
        border-left: 4px solid #27AE60;
        padding-left: 1rem;
    }
    """

    # Plant species options ("Auto-detect" lets the pipeline guess).
    plant_species = [
        "Auto-detect",
        "Apple",
        "Grape",
        "Tomato",
        "Corn/Maize",
        "Potato",
        "Pepper",
        "Strawberry",
        "Cherry",
        "Peach",
        "Citrus",
        "Rice",
        "Wheat",
        "Other"
    ]

    # Analysis profiles (trade speed against thoroughness).
    profiles = [
        "Quick Scan",
        "Standard",
        "Comprehensive",
        "Pest Focused"
    ]

    # Example images (you can add real examples here).
    # NOTE(review): currently unused — not passed to gr.Examples; the files
    # under examples/ would also need to exist for them to work.
    examples = [
        ["examples/healthy_tomato.jpg", "Tomato", "Standard"],
        ["examples/leaf_spot.jpg", "Apple", "Comprehensive"],
        ["examples/powdery_mildew.jpg", "Grape", "Standard"],
    ]

    with gr.Blocks(css=css, title="CropDoctor-Semantic") as demo:

        # Header banner.
        gr.Markdown("""
# π± CropDoctor-Semantic
### AI-Powered Plant Disease Diagnosis with SAM 3

Upload an image of a plant leaf to get instant disease diagnosis and treatment recommendations.

---
""")

        with gr.Row():
            # Left column - Input
            with gr.Column(scale=1):
                input_image = gr.Image(
                    label="π· Upload Plant Image",
                    type="numpy",
                    height=400
                )

                with gr.Row():
                    species_dropdown = gr.Dropdown(
                        choices=plant_species,
                        value="Auto-detect",
                        label="πΏ Plant Species"
                    )
                    profile_dropdown = gr.Dropdown(
                        choices=profiles,
                        value="Standard",
                        label="π Analysis Profile"
                    )

                # When enabled, SAM2 isolates the leaf before detection to
                # reduce false positives (label/info intentionally in French).
                leaf_seg_checkbox = gr.Checkbox(
                    value=True,
                    label="π Isoler la feuille (SAM2)",
                    info="Segmente la feuille avant la detection pour reduire les faux positifs"
                )

                diagnose_btn = gr.Button(
                    "π¬ Analyze Plant",
                    variant="primary",
                    size="lg"
                )

            # Right column - Output
            with gr.Column(scale=1):
                output_image = gr.Image(
                    label="π― Detection Results",
                    type="numpy",
                    height=400
                )

                # Hidden textbox; receives the severity string as the fourth
                # output of diagnose_image (kept invisible in the layout).
                severity_label = gr.Textbox(
                    label="Severity",
                    visible=False
                )

        with gr.Row():
            # Diagnosis panel
            with gr.Column(scale=1, elem_classes=["diagnosis-panel"]):
                diagnosis_output = gr.Markdown(
                    label="Diagnosis",
                    value="Upload an image and click 'Analyze Plant' to see results."
                )

            # Treatment panel
            with gr.Column(scale=1, elem_classes=["treatment-panel"]):
                treatment_output = gr.Markdown(
                    label="Treatment",
                    value="Treatment recommendations will appear here."
                )

        # Examples section (placeholder text only — see `examples` above).
        gr.Markdown("### π Example Images")
        gr.Markdown("*Click an example to try it out (examples require actual image files)*")

        # Event handlers: single click handler driving all four outputs.
        diagnose_btn.click(
            fn=diagnose_image,
            inputs=[input_image, species_dropdown, profile_dropdown, leaf_seg_checkbox],
            outputs=[output_image, diagnosis_output, treatment_output, severity_label]
        )

        # Footer
        gr.Markdown("""
---
### About

**CropDoctor-Semantic** uses cutting-edge AI technology:
- **SAM 3** (Segment Anything Model 3) for concept-based segmentation
- **Deep Learning** for severity classification
- **Claude AI** for intelligent treatment recommendations

β οΈ *This is a demo version. For production use, install SAM 3 and configure the Anthropic API.*

π **Impact**: Plant diseases cause $220 billion in agricultural losses annually.
Early detection can significantly reduce crop losses.

---
*Built with β€οΈ for sustainable agriculture*
""")

    return demo
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
if __name__ == "__main__":
    # Build the UI and serve it on all interfaces, port 7860.
    demo = create_demo()
    launch_options = {
        "share": False,  # Set to True for public link
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
    }
    demo.launch(**launch_options)
|
src/leaf_segmenter.py
ADDED
|
@@ -0,0 +1,388 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Leaf Segmentation using SAM2.
|
| 3 |
+
|
| 4 |
+
This module provides leaf segmentation functionality to isolate leaves
|
| 5 |
+
from backgrounds before disease detection.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
from PIL import Image
|
| 10 |
+
from typing import Optional, Tuple
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class SAM2LeafSegmenter:
|
| 15 |
+
"""
|
| 16 |
+
Segments leaves from images using SAM2 (Segment Anything Model 2).
|
| 17 |
+
|
| 18 |
+
This is used as a preprocessing step to:
|
| 19 |
+
1. Isolate the leaf from the background
|
| 20 |
+
2. Create a white background image with just the leaf
|
| 21 |
+
3. Reduce false positives in disease detection
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def __init__(
|
| 25 |
+
self,
|
| 26 |
+
checkpoint_path: str = "models/sam2/sam2.1_hiera_small.pt",
|
| 27 |
+
config_file: str = "configs/sam2.1/sam2.1_hiera_s.yaml",
|
| 28 |
+
device: Optional[str] = None
|
| 29 |
+
):
|
| 30 |
+
"""
|
| 31 |
+
Initialize SAM2 leaf segmenter.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
checkpoint_path: Path to SAM2 checkpoint
|
| 35 |
+
config_file: SAM2 config file name
|
| 36 |
+
device: Device to use ('cuda', 'mps', 'cpu'). Auto-detected if None.
|
| 37 |
+
"""
|
| 38 |
+
self.checkpoint_path = checkpoint_path
|
| 39 |
+
self.config_file = config_file
|
| 40 |
+
|
| 41 |
+
if device is None:
|
| 42 |
+
if torch.cuda.is_available():
|
| 43 |
+
self.device = 'cuda'
|
| 44 |
+
elif torch.backends.mps.is_available():
|
| 45 |
+
self.device = 'mps'
|
| 46 |
+
else:
|
| 47 |
+
self.device = 'cpu'
|
| 48 |
+
else:
|
| 49 |
+
self.device = device
|
| 50 |
+
|
| 51 |
+
self.model = None
|
| 52 |
+
self.predictor = None
|
| 53 |
+
|
| 54 |
+
def load_model(self):
|
| 55 |
+
"""Load SAM2 model."""
|
| 56 |
+
if self.model is not None:
|
| 57 |
+
return
|
| 58 |
+
|
| 59 |
+
from sam2.build_sam import build_sam2
|
| 60 |
+
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
| 61 |
+
|
| 62 |
+
print(f"Loading SAM2 model on {self.device}...")
|
| 63 |
+
self.model = build_sam2(
|
| 64 |
+
config_file=self.config_file,
|
| 65 |
+
ckpt_path=self.checkpoint_path,
|
| 66 |
+
device=self.device
|
| 67 |
+
)
|
| 68 |
+
self.predictor = SAM2ImagePredictor(self.model)
|
| 69 |
+
print("SAM2 model loaded.")
|
| 70 |
+
|
| 71 |
+
def segment_leaf(
|
| 72 |
+
self,
|
| 73 |
+
image: Image.Image,
|
| 74 |
+
point: Optional[Tuple[int, int]] = None,
|
| 75 |
+
return_mask: bool = False
|
| 76 |
+
) -> Image.Image | Tuple[Image.Image, np.ndarray]:
|
| 77 |
+
"""
|
| 78 |
+
Segment the leaf from the image.
|
| 79 |
+
|
| 80 |
+
Args:
|
| 81 |
+
image: PIL Image to segment
|
| 82 |
+
point: (x, y) point to indicate the leaf. If None, uses image center.
|
| 83 |
+
return_mask: If True, also returns the binary mask
|
| 84 |
+
|
| 85 |
+
Returns:
|
| 86 |
+
Image with leaf on white background (and mask if return_mask=True)
|
| 87 |
+
"""
|
| 88 |
+
self.load_model()
|
| 89 |
+
|
| 90 |
+
# Convert to numpy array
|
| 91 |
+
image_np = np.array(image.convert('RGB'))
|
| 92 |
+
h, w = image_np.shape[:2]
|
| 93 |
+
|
| 94 |
+
# Use center point if not specified
|
| 95 |
+
if point is None:
|
| 96 |
+
point = (w // 2, h // 2)
|
| 97 |
+
|
| 98 |
+
# Set image for predictor
|
| 99 |
+
self.predictor.set_image(image_np)
|
| 100 |
+
|
| 101 |
+
# Predict mask using point prompt
|
| 102 |
+
point_coords = np.array([[point[0], point[1]]])
|
| 103 |
+
point_labels = np.array([1]) # 1 = foreground
|
| 104 |
+
|
| 105 |
+
masks, scores, _ = self.predictor.predict(
|
| 106 |
+
point_coords=point_coords,
|
| 107 |
+
point_labels=point_labels,
|
| 108 |
+
multimask_output=True
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
# Select best mask (highest score)
|
| 112 |
+
best_idx = np.argmax(scores)
|
| 113 |
+
mask = masks[best_idx].astype(bool)
|
| 114 |
+
|
| 115 |
+
# Create white background image
|
| 116 |
+
result = np.ones_like(image_np) * 255 # White background
|
| 117 |
+
result[mask] = image_np[mask] # Copy leaf pixels
|
| 118 |
+
|
| 119 |
+
result_image = Image.fromarray(result.astype(np.uint8))
|
| 120 |
+
|
| 121 |
+
if return_mask:
|
| 122 |
+
return result_image, mask
|
| 123 |
+
return result_image
|
| 124 |
+
|
| 125 |
+
def segment_leaf_with_bbox(
|
| 126 |
+
self,
|
| 127 |
+
image: Image.Image,
|
| 128 |
+
bbox: Optional[Tuple[int, int, int, int]] = None,
|
| 129 |
+
return_mask: bool = False
|
| 130 |
+
) -> Image.Image | Tuple[Image.Image, np.ndarray]:
|
| 131 |
+
"""
|
| 132 |
+
Segment the leaf using a bounding box prompt.
|
| 133 |
+
|
| 134 |
+
Args:
|
| 135 |
+
image: PIL Image to segment
|
| 136 |
+
bbox: (x1, y1, x2, y2) bounding box. If None, uses full image.
|
| 137 |
+
return_mask: If True, also returns the binary mask
|
| 138 |
+
|
| 139 |
+
Returns:
|
| 140 |
+
Image with leaf on white background (and mask if return_mask=True)
|
| 141 |
+
"""
|
| 142 |
+
self.load_model()
|
| 143 |
+
|
| 144 |
+
# Convert to numpy array
|
| 145 |
+
image_np = np.array(image.convert('RGB'))
|
| 146 |
+
h, w = image_np.shape[:2]
|
| 147 |
+
|
| 148 |
+
# Use full image bbox if not specified
|
| 149 |
+
if bbox is None:
|
| 150 |
+
# Use slightly inset bbox to focus on leaf
|
| 151 |
+
margin = min(w, h) // 20
|
| 152 |
+
bbox = (margin, margin, w - margin, h - margin)
|
| 153 |
+
|
| 154 |
+
# Set image for predictor
|
| 155 |
+
self.predictor.set_image(image_np)
|
| 156 |
+
|
| 157 |
+
# Predict mask using box prompt
|
| 158 |
+
box = np.array([bbox])
|
| 159 |
+
|
| 160 |
+
masks, scores, _ = self.predictor.predict(
|
| 161 |
+
box=box,
|
| 162 |
+
multimask_output=True
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
# Select best mask (highest score)
|
| 166 |
+
best_idx = np.argmax(scores)
|
| 167 |
+
mask = masks[best_idx].astype(bool)
|
| 168 |
+
|
| 169 |
+
# Create white background image
|
| 170 |
+
result = np.ones_like(image_np) * 255 # White background
|
| 171 |
+
result[mask] = image_np[mask] # Copy leaf pixels
|
| 172 |
+
|
| 173 |
+
result_image = Image.fromarray(result.astype(np.uint8))
|
| 174 |
+
|
| 175 |
+
if return_mask:
|
| 176 |
+
return result_image, mask
|
| 177 |
+
return result_image
|
| 178 |
+
|
| 179 |
+
def auto_segment_leaf(
|
| 180 |
+
self,
|
| 181 |
+
image: Image.Image,
|
| 182 |
+
return_mask: bool = False
|
| 183 |
+
) -> Image.Image | Tuple[Image.Image, np.ndarray]:
|
| 184 |
+
"""
|
| 185 |
+
Automatically segment the main leaf/plant from the image.
|
| 186 |
+
|
| 187 |
+
Uses multiple strategies to find the best segmentation:
|
| 188 |
+
1. Center point
|
| 189 |
+
2. Multiple points in a grid
|
| 190 |
+
3. Green color detection for better point selection
|
| 191 |
+
4. Selects the largest coherent mask
|
| 192 |
+
|
| 193 |
+
Args:
|
| 194 |
+
image: PIL Image to segment
|
| 195 |
+
return_mask: If True, also returns the binary mask
|
| 196 |
+
|
| 197 |
+
Returns:
|
| 198 |
+
Image with leaf on white background (and mask if return_mask=True)
|
| 199 |
+
"""
|
| 200 |
+
self.load_model()
|
| 201 |
+
|
| 202 |
+
# Convert to numpy array
|
| 203 |
+
image_np = np.array(image.convert('RGB'))
|
| 204 |
+
h, w = image_np.shape[:2]
|
| 205 |
+
|
| 206 |
+
# Set image for predictor
|
| 207 |
+
self.predictor.set_image(image_np)
|
| 208 |
+
|
| 209 |
+
# Try to find a good point on the leaf using green color detection
|
| 210 |
+
# Convert to HSV for better color detection
|
| 211 |
+
from PIL import ImageFilter
|
| 212 |
+
import colorsys
|
| 213 |
+
|
| 214 |
+
# Simple green detection: look for pixels with green hue
|
| 215 |
+
green_mask = self._detect_green_regions(image_np)
|
| 216 |
+
|
| 217 |
+
# Find centroid of green regions, fallback to image center
|
| 218 |
+
if green_mask.sum() > 100: # At least some green pixels
|
| 219 |
+
y_coords, x_coords = np.where(green_mask)
|
| 220 |
+
center_x = int(np.median(x_coords))
|
| 221 |
+
center_y = int(np.median(y_coords))
|
| 222 |
+
else:
|
| 223 |
+
center_x, center_y = w // 2, h // 2
|
| 224 |
+
|
| 225 |
+
# Try multiple points for robustness
|
| 226 |
+
points_to_try = [
|
| 227 |
+
(center_x, center_y), # Green centroid or center
|
| 228 |
+
(w // 2, h // 2), # Image center
|
| 229 |
+
(w // 3, h // 2), # Left third
|
| 230 |
+
(2 * w // 3, h // 2), # Right third
|
| 231 |
+
]
|
| 232 |
+
|
| 233 |
+
best_mask = None
|
| 234 |
+
best_score = -1
|
| 235 |
+
|
| 236 |
+
for px, py in points_to_try:
|
| 237 |
+
point = np.array([[px, py]])
|
| 238 |
+
label = np.array([1])
|
| 239 |
+
|
| 240 |
+
masks, scores, _ = self.predictor.predict(
|
| 241 |
+
point_coords=point,
|
| 242 |
+
point_labels=label,
|
| 243 |
+
multimask_output=True
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
for mask, score in zip(masks, scores):
|
| 247 |
+
# Ensure mask is boolean for indexing
|
| 248 |
+
mask = mask.astype(bool)
|
| 249 |
+
|
| 250 |
+
# Calculate mask coverage
|
| 251 |
+
coverage = mask.sum() / (h * w)
|
| 252 |
+
|
| 253 |
+
# Prefer masks that cover 5-95% of image (more flexible range)
|
| 254 |
+
if 0.05 < coverage < 0.95:
|
| 255 |
+
# Check if mask contains green (likely a leaf)
|
| 256 |
+
green_in_mask = green_mask[mask].sum() / max(mask.sum(), 1)
|
| 257 |
+
|
| 258 |
+
# Bonus for being closer to 30-70% coverage
|
| 259 |
+
coverage_score = 1 - abs(coverage - 0.5)
|
| 260 |
+
|
| 261 |
+
# Combined score: SAM confidence + coverage + greenness
|
| 262 |
+
combined_score = score * 0.5 + coverage_score * 0.2 + green_in_mask * 0.3
|
| 263 |
+
|
| 264 |
+
if combined_score > best_score:
|
| 265 |
+
best_score = combined_score
|
| 266 |
+
best_mask = mask
|
| 267 |
+
|
| 268 |
+
# Fallback to highest score mask from center point
|
| 269 |
+
if best_mask is None:
|
| 270 |
+
center_point = np.array([[w // 2, h // 2]])
|
| 271 |
+
center_label = np.array([1])
|
| 272 |
+
masks, scores, _ = self.predictor.predict(
|
| 273 |
+
point_coords=center_point,
|
| 274 |
+
point_labels=center_label,
|
| 275 |
+
multimask_output=True
|
| 276 |
+
)
|
| 277 |
+
best_idx = np.argmax(scores)
|
| 278 |
+
best_mask = masks[best_idx]
|
| 279 |
+
|
| 280 |
+
# Ensure mask is boolean
|
| 281 |
+
best_mask = best_mask.astype(bool)
|
| 282 |
+
|
| 283 |
+
# Create white background image
|
| 284 |
+
result = np.ones_like(image_np) * 255 # White background
|
| 285 |
+
result[best_mask] = image_np[best_mask] # Copy leaf pixels
|
| 286 |
+
|
| 287 |
+
result_image = Image.fromarray(result.astype(np.uint8))
|
| 288 |
+
|
| 289 |
+
if return_mask:
|
| 290 |
+
return result_image, best_mask
|
| 291 |
+
return result_image
|
| 292 |
+
|
| 293 |
+
def _detect_green_regions(self, image_np: np.ndarray) -> np.ndarray:
|
| 294 |
+
"""Detect green regions in image (likely leaf areas)."""
|
| 295 |
+
# Convert RGB to HSV for better green detection
|
| 296 |
+
r, g, b = image_np[:,:,0], image_np[:,:,1], image_np[:,:,2]
|
| 297 |
+
|
| 298 |
+
# Green typically has: g > r, g > b, and reasonable brightness
|
| 299 |
+
green_mask = (
|
| 300 |
+
(g > r * 0.9) & # Green channel dominant over red
|
| 301 |
+
(g > b * 0.9) & # Green channel dominant over blue
|
| 302 |
+
(g > 40) & # Not too dark
|
| 303 |
+
(g < 250) # Not too bright (white)
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
# Also detect yellow-green (common in leaves)
|
| 307 |
+
yellow_green = (
|
| 308 |
+
(g > 50) &
|
| 309 |
+
(r > 50) &
|
| 310 |
+
(b < r * 0.8) & # Blue much less than red
|
| 311 |
+
(abs(g.astype(int) - r.astype(int)) < 80) # R and G similar
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
return green_mask | yellow_green
|
| 315 |
+
|
| 316 |
+
def refine_boxes_to_masks(
|
| 317 |
+
self,
|
| 318 |
+
image: Image.Image,
|
| 319 |
+
boxes: np.ndarray,
|
| 320 |
+
return_scores: bool = False
|
| 321 |
+
) -> np.ndarray | Tuple[np.ndarray, np.ndarray]:
|
| 322 |
+
"""
|
| 323 |
+
Refine bounding boxes into precise segmentation masks using SAM2.
|
| 324 |
+
|
| 325 |
+
This is used to convert RF-DETR detection boxes into proper
|
| 326 |
+
segmentation masks for disease regions.
|
| 327 |
+
|
| 328 |
+
Args:
|
| 329 |
+
image: PIL Image
|
| 330 |
+
boxes: Array of bounding boxes [N, 4] in xyxy format
|
| 331 |
+
return_scores: If True, also returns confidence scores
|
| 332 |
+
|
| 333 |
+
Returns:
|
| 334 |
+
Array of masks [N, H, W] (and scores if return_scores=True)
|
| 335 |
+
"""
|
| 336 |
+
self.load_model()
|
| 337 |
+
|
| 338 |
+
# Convert to numpy array
|
| 339 |
+
image_np = np.array(image.convert('RGB'))
|
| 340 |
+
h, w = image_np.shape[:2]
|
| 341 |
+
|
| 342 |
+
if len(boxes) == 0:
|
| 343 |
+
empty_masks = np.zeros((0, h, w), dtype=bool)
|
| 344 |
+
if return_scores:
|
| 345 |
+
return empty_masks, np.zeros((0,), dtype=np.float32)
|
| 346 |
+
return empty_masks
|
| 347 |
+
|
| 348 |
+
# Set image for predictor
|
| 349 |
+
self.predictor.set_image(image_np)
|
| 350 |
+
|
| 351 |
+
masks_list = []
|
| 352 |
+
scores_list = []
|
| 353 |
+
|
| 354 |
+
for box in boxes:
|
| 355 |
+
# Use box prompt for SAM2
|
| 356 |
+
box_np = np.array([box])
|
| 357 |
+
|
| 358 |
+
masks, scores, _ = self.predictor.predict(
|
| 359 |
+
box=box_np,
|
| 360 |
+
multimask_output=True
|
| 361 |
+
)
|
| 362 |
+
|
| 363 |
+
# Select best mask (highest score)
|
| 364 |
+
best_idx = np.argmax(scores)
|
| 365 |
+
best_mask = masks[best_idx].astype(bool)
|
| 366 |
+
best_score = scores[best_idx]
|
| 367 |
+
|
| 368 |
+
masks_list.append(best_mask)
|
| 369 |
+
scores_list.append(best_score)
|
| 370 |
+
|
| 371 |
+
result_masks = np.stack(masks_list, axis=0) if masks_list else np.zeros((0, h, w), dtype=bool)
|
| 372 |
+
result_scores = np.array(scores_list, dtype=np.float32)
|
| 373 |
+
|
| 374 |
+
if return_scores:
|
| 375 |
+
return result_masks, result_scores
|
| 376 |
+
return result_masks
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
# Convenience function
def create_leaf_segmenter(
    checkpoint_path: str = "models/sam2/sam2.1_hiera_small.pt",
    device: Optional[str] = None
) -> SAM2LeafSegmenter:
    """Convenience factory: build a SAM2LeafSegmenter with the default config."""
    segmenter = SAM2LeafSegmenter(checkpoint_path=checkpoint_path, device=device)
    return segmenter
|
src/pipeline.py
ADDED
|
@@ -0,0 +1,631 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
CropDoctor Pipeline Module
|
| 3 |
+
==========================
|
| 4 |
+
|
| 5 |
+
This module integrates all components into a unified diagnostic pipeline:
|
| 6 |
+
1. SAM 3 Segmentation - Detect disease regions
|
| 7 |
+
2. Severity Classification - Assess severity level
|
| 8 |
+
3. Treatment Recommendation - Generate actionable advice
|
| 9 |
+
|
| 10 |
+
Provides both single-image and batch processing capabilities.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import numpy as np
|
| 15 |
+
from PIL import Image
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from typing import Dict, List, Optional, Union, Tuple
|
| 18 |
+
from dataclasses import dataclass, asdict
|
| 19 |
+
import json
|
| 20 |
+
import csv
|
| 21 |
+
from datetime import datetime
|
| 22 |
+
import logging
|
| 23 |
+
|
| 24 |
+
from .sam3_segmentation import SAM3Segmenter, create_segmenter, SegmentationResult
|
| 25 |
+
from .severity_classifier import SeverityClassifier, SeverityPrediction, SEVERITY_LABELS
|
| 26 |
+
from .treatment_recommender import TreatmentRecommender, TreatmentRecommendation
|
| 27 |
+
from .leaf_segmenter import SAM2LeafSegmenter
|
| 28 |
+
|
| 29 |
+
logger = logging.getLogger(__name__)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@dataclass
class DiagnosticResult:
    """Complete diagnostic result for an image.

    Aggregates the outputs of the three pipeline stages (segmentation,
    severity classification, treatment recommendation) for one input image.
    """
    # Image info
    image_path: str   # path/identifier of the analyzed image
    timestamp: str    # when the diagnosis was produced

    # Segmentation results
    num_regions_detected: int        # number of disease regions found
    affected_area_percent: float     # percentage of area flagged as affected
    detected_symptoms: List[str]     # human-readable symptom descriptions

    # Severity assessment
    severity_level: int              # numeric severity level (rendered as "level/3" in the UI)
    severity_label: str              # textual label corresponding to severity_level
    severity_confidence: float       # classifier confidence (rendered as a percentage in the UI)

    # Treatment recommendations
    disease_name: str
    disease_type: str                # presumably a category such as fungal/bacterial — verify against recommender
    disease_confidence: float
    organic_treatments: List[str]    # ordered list; UI shows the first 5
    chemical_treatments: List[str]   # ordered list; UI shows the first 3 (may be empty)
    preventive_measures: List[str]   # ordered list; UI shows the first 5
    treatment_timing: str            # free-text application-timing advice
    urgency: str                     # "critical" triggers immediate-action messaging in the UI

    # Raw data for further analysis (only populated when masks are requested)
    segmentation_masks: Optional[np.ndarray] = None    # per-region masks from the segmenter
    segmentation_scores: Optional[np.ndarray] = None   # per-region confidence scores
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class CropDoctorPipeline:
|
| 65 |
+
"""
|
| 66 |
+
End-to-end pipeline for plant disease diagnosis.
|
| 67 |
+
|
| 68 |
+
Integrates SAM 3 segmentation, severity classification, and
|
| 69 |
+
LLM-based treatment recommendations into a single workflow.
|
| 70 |
+
|
| 71 |
+
Example:
|
| 72 |
+
>>> pipeline = CropDoctorPipeline()
|
| 73 |
+
>>> result = pipeline.diagnose("path/to/leaf.jpg")
|
| 74 |
+
>>> print(f"Disease: {result.disease_name}")
|
| 75 |
+
>>> print(f"Severity: {result.severity_label}")
|
| 76 |
+
>>> print(f"Treatment: {result.organic_treatments[0]}")
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
def __init__(
|
| 80 |
+
self,
|
| 81 |
+
sam3_checkpoint: str = "models/sam3/sam3.pt",
|
| 82 |
+
sam3_config: str = "configs/sam3_config.yaml",
|
| 83 |
+
classifier_checkpoint: Optional[str] = None,
|
| 84 |
+
use_llm: bool = True,
|
| 85 |
+
anthropic_api_key: Optional[str] = None,
|
| 86 |
+
device: Optional[str] = None,
|
| 87 |
+
use_mock_sam3: bool = False,
|
| 88 |
+
use_rfdetr: bool = False,
|
| 89 |
+
rfdetr_checkpoint: str = "models/rfdetr/checkpoint_best_total.pth",
|
| 90 |
+
rfdetr_model_size: str = "medium",
|
| 91 |
+
use_leaf_segmentation: bool = False,
|
| 92 |
+
sam2_checkpoint: str = "models/sam2/sam2.1_hiera_small.pt"
|
| 93 |
+
):
|
| 94 |
+
"""
|
| 95 |
+
Initialize the CropDoctor pipeline.
|
| 96 |
+
|
| 97 |
+
Args:
|
| 98 |
+
sam3_checkpoint: Path to SAM 3 model checkpoint
|
| 99 |
+
sam3_config: Path to SAM 3 configuration
|
| 100 |
+
classifier_checkpoint: Path to severity classifier checkpoint
|
| 101 |
+
use_llm: Whether to use Claude API for recommendations
|
| 102 |
+
anthropic_api_key: Optional API key for Claude
|
| 103 |
+
device: Device to use (auto-detected if None)
|
| 104 |
+
use_mock_sam3: Use mock SAM 3 for testing without model
|
| 105 |
+
use_rfdetr: Use RF-DETR for detection (recommended)
|
| 106 |
+
rfdetr_checkpoint: Path to trained RF-DETR model
|
| 107 |
+
rfdetr_model_size: RF-DETR model size (nano, small, medium, base)
|
| 108 |
+
use_leaf_segmentation: Use SAM2 to segment leaf before detection
|
| 109 |
+
sam2_checkpoint: Path to SAM2 checkpoint for leaf segmentation
|
| 110 |
+
"""
|
| 111 |
+
# Set device
|
| 112 |
+
if device is None:
|
| 113 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 114 |
+
else:
|
| 115 |
+
self.device = device
|
| 116 |
+
|
| 117 |
+
logger.info(f"Initializing CropDoctor Pipeline on {self.device}")
|
| 118 |
+
|
| 119 |
+
# Initialize components (lazy loading)
|
| 120 |
+
self._segmenter = None
|
| 121 |
+
self._classifier = None
|
| 122 |
+
self._recommender = None
|
| 123 |
+
self._leaf_segmenter = None
|
| 124 |
+
|
| 125 |
+
# Store config
|
| 126 |
+
self.sam3_checkpoint = sam3_checkpoint
|
| 127 |
+
self.sam3_config = sam3_config
|
| 128 |
+
self.classifier_checkpoint = classifier_checkpoint
|
| 129 |
+
self.use_llm = use_llm
|
| 130 |
+
self.anthropic_api_key = anthropic_api_key
|
| 131 |
+
self.use_mock_sam3 = use_mock_sam3
|
| 132 |
+
self.use_rfdetr = use_rfdetr
|
| 133 |
+
self.rfdetr_checkpoint = rfdetr_checkpoint
|
| 134 |
+
self.rfdetr_model_size = rfdetr_model_size
|
| 135 |
+
self.use_leaf_segmentation = use_leaf_segmentation
|
| 136 |
+
self.sam2_checkpoint = sam2_checkpoint
|
| 137 |
+
|
| 138 |
+
# Default prompts for disease detection
|
| 139 |
+
self.disease_prompts = [
|
| 140 |
+
"diseased plant tissue",
|
| 141 |
+
"leaf with brown spots",
|
| 142 |
+
"powdery mildew",
|
| 143 |
+
"yellowing leaves",
|
| 144 |
+
"pest damage",
|
| 145 |
+
"wilted tissue",
|
| 146 |
+
"healthy green leaf"
|
| 147 |
+
]
|
| 148 |
+
|
| 149 |
+
@property
def segmenter(self) -> SAM3Segmenter:
    """Detection backend (SAM 3, mock, or RF-DETR), built on first access.

    The concrete implementation is chosen by ``create_segmenter`` from the
    flags stored at construction time; subsequent accesses reuse the cached
    instance.
    """
    if self._segmenter is not None:
        return self._segmenter
    self._segmenter = create_segmenter(
        self.sam3_checkpoint,
        self.sam3_config,
        use_mock=self.use_mock_sam3,
        use_rfdetr=self.use_rfdetr,
        rfdetr_checkpoint=self.rfdetr_checkpoint,
        rfdetr_model_size=self.rfdetr_model_size
    )
    return self._segmenter
|
| 162 |
+
|
| 163 |
+
@property
def classifier(self) -> SeverityClassifier:
    """Severity classifier, instantiated lazily on first use and cached."""
    if self._classifier is not None:
        return self._classifier
    self._classifier = SeverityClassifier(
        checkpoint_path=self.classifier_checkpoint,
        device=self.device
    )
    return self._classifier
|
| 172 |
+
|
| 173 |
+
@property
def recommender(self) -> TreatmentRecommender:
    """Treatment recommender, instantiated lazily on first use and cached."""
    if self._recommender is not None:
        return self._recommender
    self._recommender = TreatmentRecommender(
        api_key=self.anthropic_api_key,
        use_llm=self.use_llm
    )
    return self._recommender
|
| 182 |
+
|
| 183 |
+
@property
def leaf_segmenter(self) -> SAM2LeafSegmenter:
    """SAM2 leaf segmenter, instantiated lazily on first use and cached."""
    if self._leaf_segmenter is not None:
        return self._leaf_segmenter
    self._leaf_segmenter = SAM2LeafSegmenter(
        checkpoint_path=self.sam2_checkpoint,
        device=self.device
    )
    return self._leaf_segmenter
|
| 192 |
+
|
| 193 |
+
def diagnose(
    self,
    image: Union[str, Path, Image.Image, np.ndarray],
    plant_species: Optional[str] = None,
    analysis_profile: str = "standard",
    custom_prompts: Optional[List[str]] = None,
    return_masks: bool = False
) -> DiagnosticResult:
    """
    Perform complete diagnosis on an image.

    Pipeline stages:
      0. (optional) isolate the leaf from the background with SAM2
      1. text-prompted disease detection (SAM3/mock/RF-DETR backend)
      1.5 (RF-DETR only) refine rectangular boxes into SAM2 masks
      2. severity classification, sanity-checked against affected area
      3. treatment recommendation (skipped for healthy plants)

    Args:
        image: Input image (path, PIL Image, or numpy array)
        plant_species: Optional plant species for context
        analysis_profile: SAM 3 analysis profile
            NOTE(review): currently unused in this method -- detection always
            runs on ``custom_prompts`` or ``self.disease_prompts``; confirm
            whether it should select a profile via ``segment_disease_regions``.
        custom_prompts: Optional custom prompts for segmentation
        return_masks: Whether to include segmentation masks in result

    Returns:
        DiagnosticResult with complete diagnosis
    """
    timestamp = datetime.now().isoformat()
    # Path inputs keep their path in the report; in-memory inputs get a marker.
    image_path = str(image) if isinstance(image, (str, Path)) else "in_memory"

    logger.info(f"Starting diagnosis for {image_path}")

    # Load image if needed (normalize everything to a PIL RGB image)
    if isinstance(image, (str, Path)):
        pil_image = Image.open(image).convert("RGB")
    elif isinstance(image, np.ndarray):
        pil_image = Image.fromarray(image)
    else:
        pil_image = image

    # Store original image for visualization
    # NOTE(review): ``original_image`` and ``leaf_mask`` are assigned but never
    # read later in this method -- confirm whether they were meant to be
    # returned or used for visualization.
    original_image = pil_image
    leaf_mask = None

    # Step 0 (optional): Leaf segmentation with SAM2
    # Replaces pil_image with a background-free crop so detection runs on
    # the leaf only.
    if self.use_leaf_segmentation:
        logger.info("Step 0: Leaf segmentation with SAM2")
        pil_image, leaf_mask = self.leaf_segmenter.auto_segment_leaf(
            pil_image,
            return_mask=True
        )
        logger.info("Leaf isolated from background")

    # Step 1: Disease Detection (SAM3/RF-DETR)
    logger.info("Step 1: Disease detection")
    prompts = custom_prompts or self.disease_prompts
    seg_result = self.segmenter.segment_with_concepts(pil_image, prompts)

    # Step 1.5: Refine bounding boxes to proper masks using SAM2
    # This converts RF-DETR rectangular boxes to precise segmentation masks
    if self.use_rfdetr and len(seg_result.boxes) > 0:
        logger.info("Step 1.5: Refining detection boxes with SAM2")
        refined_masks = self.leaf_segmenter.refine_boxes_to_masks(
            pil_image,
            seg_result.boxes
        )
        # Replace rectangular masks with refined masks (boxes/scores unchanged)
        seg_result = SegmentationResult(
            masks=refined_masks,
            boxes=seg_result.boxes,
            scores=seg_result.scores,
            prompts=seg_result.prompts,
            prompt_indices=seg_result.prompt_indices
        )

    # Calculate affected area; the "healthy green leaf" prompt (if present)
    # is passed so the segmenter can exclude healthy regions.
    area_stats = self.segmenter.calculate_affected_area(
        seg_result,
        healthy_prompt_idx=prompts.index("healthy green leaf") if "healthy green leaf" in prompts else None
    )

    # Get detected symptoms (prompts with detections), excluding the healthy prompt
    detected_symptoms = []
    for prompt_idx in np.unique(seg_result.prompt_indices):
        if prompt_idx < len(prompts):
            prompt = prompts[prompt_idx]
            if prompt != "healthy green leaf":
                detected_symptoms.append(prompt)

    # Step 2: Severity Classification
    logger.info("Step 2: Severity classification")

    # Create combined disease mask for classification
    if len(seg_result.masks) > 0:
        # Combine all disease masks (excluding healthy) with a boolean OR
        disease_mask = np.zeros(seg_result.masks[0].shape, dtype=bool)
        for i, mask in enumerate(seg_result.masks):
            prompt_idx = seg_result.prompt_indices[i]
            if prompt_idx < len(prompts) and prompts[prompt_idx] != "healthy green leaf":
                disease_mask |= mask
    else:
        disease_mask = None

    severity_result = self.classifier.classify(pil_image, mask=disease_mask)

    # Override severity based on affected area if needed:
    # <1% affected is forced to "healthy" regardless of classifier output;
    # <10% caps the severity at "mild".
    if area_stats["total_affected_percent"] < 1:
        severity_result = SeverityPrediction(
            severity_level=0,
            severity_label="healthy",
            confidence=0.9,
            probabilities={"healthy": 0.9, "mild": 0.05, "moderate": 0.03, "severe": 0.02},
            affected_area_percent=area_stats["total_affected_percent"]
        )
    elif area_stats["total_affected_percent"] < 10 and severity_result.severity_level > 1:
        severity_result.severity_level = 1
        severity_result.severity_label = "mild"

    # Step 3: Treatment Recommendations
    logger.info("Step 3: Generating treatment recommendations")

    if detected_symptoms and severity_result.severity_level > 0:
        treatment_result = self.recommender.get_recommendation(
            symptoms=detected_symptoms,
            severity=severity_result.severity_label,
            plant_species=plant_species,
            affected_area_percent=area_stats["total_affected_percent"]
        )
    else:
        # Healthy plant - no treatment needed; emit a canned recommendation
        treatment_result = TreatmentRecommendation(
            disease_name="No Disease Detected",
            disease_type="healthy",
            confidence=0.9,
            symptoms_matched=[],
            organic_treatments=["Continue regular care"],
            chemical_treatments=[],
            preventive_measures=[
                "Maintain good air circulation",
                "Water at soil level",
                "Monitor regularly for early symptoms"
            ],
            timing="Regular monitoring recommended",
            urgency="low",
            additional_notes="Plant appears healthy. Continue preventive care."
        )

    # Compile final result (masks are heavy, so they are only attached on request)
    result = DiagnosticResult(
        image_path=image_path,
        timestamp=timestamp,
        num_regions_detected=len(seg_result.masks),
        affected_area_percent=area_stats["total_affected_percent"],
        detected_symptoms=detected_symptoms,
        severity_level=severity_result.severity_level,
        severity_label=severity_result.severity_label,
        severity_confidence=severity_result.confidence,
        disease_name=treatment_result.disease_name,
        disease_type=treatment_result.disease_type,
        disease_confidence=treatment_result.confidence,
        organic_treatments=treatment_result.organic_treatments,
        chemical_treatments=treatment_result.chemical_treatments,
        preventive_measures=treatment_result.preventive_measures,
        treatment_timing=treatment_result.timing,
        urgency=treatment_result.urgency,
        segmentation_masks=seg_result.masks if return_masks else None,
        segmentation_scores=seg_result.scores if return_masks else None
    )

    logger.info(f"Diagnosis complete: {result.disease_name} ({result.severity_label})")

    return result
|
| 359 |
+
|
| 360 |
+
def batch_diagnose(
    self,
    image_folder: Union[str, Path],
    output_dir: Optional[Union[str, Path]] = None,
    plant_species: Optional[str] = None,
    save_visualizations: bool = True,
    file_extensions: Optional[List[str]] = None
) -> List[DiagnosticResult]:
    """
    Process multiple images from a folder.

    Args:
        image_folder: Path to folder containing images
        output_dir: Where to save results (optional)
        plant_species: Plant species for all images
        save_visualizations: Whether to save visualization images
        file_extensions: Image file extensions to process
            (defaults to [".jpg", ".jpeg", ".png"])

    Returns:
        List of DiagnosticResult for each successfully processed image;
        failures are logged and skipped.
    """
    # Default resolved here instead of a mutable default argument, which
    # would be shared (and mutable) across calls.
    if file_extensions is None:
        file_extensions = [".jpg", ".jpeg", ".png"]

    image_folder = Path(image_folder)

    if output_dir:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

    # Find all images. A set avoids double-counting the same file when the
    # lower- and upper-case globs overlap on case-insensitive filesystems;
    # sorting makes the processing order deterministic.
    found = set()
    for ext in file_extensions:
        found.update(image_folder.glob(f"*{ext}"))
        found.update(image_folder.glob(f"*{ext.upper()}"))
    images = sorted(found)

    logger.info(f"Found {len(images)} images to process")

    results = []

    for i, img_path in enumerate(images):
        logger.info(f"Processing {i+1}/{len(images)}: {img_path.name}")

        try:
            result = self.diagnose(
                img_path,
                plant_species=plant_species,
                return_masks=save_visualizations
            )
            results.append(result)

            # Save visualization if requested
            if save_visualizations and output_dir:
                self._save_visualization(
                    img_path,
                    result,
                    output_dir / f"{img_path.stem}_diagnosis.png"
                )

        except Exception as e:
            # Best-effort batch: log the failure and continue with the rest.
            logger.error(f"Error processing {img_path}: {e}")
            continue

    logger.info(f"Batch processing complete: {len(results)}/{len(images)} successful")

    return results
|
| 423 |
+
|
| 424 |
+
def _save_visualization(
    self,
    image_path: Path,
    result: DiagnosticResult,
    output_path: Path
):
    """Render the diagnostic overlay for one image and save it as a PNG."""
    # Imported lazily so matplotlib/visualization deps are only required
    # when visualizations are actually produced.
    from .visualization import create_diagnostic_visualization
    import matplotlib.pyplot as plt

    source_image = Image.open(image_path)

    figure = create_diagnostic_visualization(
        source_image,
        result.segmentation_masks,
        result.severity_label,
        result.disease_name,
        result.affected_area_percent
    )

    figure.savefig(output_path, dpi=150, bbox_inches='tight')
    plt.close(figure)
|
| 447 |
+
|
| 448 |
+
def export_report(
    self,
    results: List["DiagnosticResult"],
    output_path: Union[str, Path],
    format: str = "csv"
):
    """
    Export results to file.

    Args:
        results: List of diagnostic results
        output_path: Output file path
        format: Output format ("csv", "json")

    Raises:
        ValueError: if ``format`` is not one of the supported formats.
    """
    output_path = Path(output_path)

    # Guard clause first, then dispatch to the matching writer.
    if format not in ("csv", "json"):
        raise ValueError(f"Unknown format: {format}")

    if format == "csv":
        self._export_csv(results, output_path)
    else:
        self._export_json(results, output_path)

    logger.info(f"Report exported to {output_path}")
|
| 472 |
+
|
| 473 |
+
def _export_csv(self, results: List[DiagnosticResult], output_path: Path):
|
| 474 |
+
"""Export to CSV."""
|
| 475 |
+
with open(output_path, 'w', newline='') as f:
|
| 476 |
+
if not results:
|
| 477 |
+
return
|
| 478 |
+
|
| 479 |
+
# Get fields (excluding numpy arrays)
|
| 480 |
+
fields = [
|
| 481 |
+
'image_path', 'timestamp', 'num_regions_detected',
|
| 482 |
+
'affected_area_percent', 'detected_symptoms',
|
| 483 |
+
'severity_level', 'severity_label', 'severity_confidence',
|
| 484 |
+
'disease_name', 'disease_type', 'disease_confidence',
|
| 485 |
+
'organic_treatments', 'urgency'
|
| 486 |
+
]
|
| 487 |
+
|
| 488 |
+
writer = csv.DictWriter(f, fieldnames=fields)
|
| 489 |
+
writer.writeheader()
|
| 490 |
+
|
| 491 |
+
for result in results:
|
| 492 |
+
row = {
|
| 493 |
+
'image_path': result.image_path,
|
| 494 |
+
'timestamp': result.timestamp,
|
| 495 |
+
'num_regions_detected': result.num_regions_detected,
|
| 496 |
+
'affected_area_percent': f"{result.affected_area_percent:.2f}",
|
| 497 |
+
'detected_symptoms': '; '.join(result.detected_symptoms),
|
| 498 |
+
'severity_level': result.severity_level,
|
| 499 |
+
'severity_label': result.severity_label,
|
| 500 |
+
'severity_confidence': f"{result.severity_confidence:.3f}",
|
| 501 |
+
'disease_name': result.disease_name,
|
| 502 |
+
'disease_type': result.disease_type,
|
| 503 |
+
'disease_confidence': f"{result.disease_confidence:.3f}",
|
| 504 |
+
'organic_treatments': '; '.join(result.organic_treatments[:3]),
|
| 505 |
+
'urgency': result.urgency
|
| 506 |
+
}
|
| 507 |
+
writer.writerow(row)
|
| 508 |
+
|
| 509 |
+
def _export_json(self, results: List[DiagnosticResult], output_path: Path):
|
| 510 |
+
"""Export to JSON."""
|
| 511 |
+
data = []
|
| 512 |
+
for result in results:
|
| 513 |
+
d = asdict(result)
|
| 514 |
+
# Remove numpy arrays
|
| 515 |
+
d.pop('segmentation_masks', None)
|
| 516 |
+
d.pop('segmentation_scores', None)
|
| 517 |
+
data.append(d)
|
| 518 |
+
|
| 519 |
+
with open(output_path, 'w') as f:
|
| 520 |
+
json.dump(data, f, indent=2, default=str)
|
| 521 |
+
|
| 522 |
+
def generate_summary_report(
    self,
    results: List["DiagnosticResult"]
) -> str:
    """
    Generate a formatted summary report for a batch of diagnoses.

    Args:
        results: List of diagnostic results

    Returns:
        Human-readable summary string (severity distribution, disease
        frequency, average affected area, and up to 5 urgent cases).
    """
    if not results:
        return "No results to summarize."

    total = len(results)

    # Single pass over the results for severity counts and disease frequency.
    severity_counts = [0, 0, 0, 0]
    disease_counts = {}
    for r in results:
        if 0 <= r.severity_level <= 3:
            severity_counts[r.severity_level] += 1
        disease_counts[r.disease_name] = disease_counts.get(r.disease_name, 0) + 1
    healthy, mild, moderate, severe = severity_counts

    # Average affected area across the whole batch.
    avg_affected = np.mean([r.affected_area_percent for r in results])

    report = f"""
ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
β                 BATCH DIAGNOSIS SUMMARY                      β
ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ

π OVERALL STATISTICS
Total Images Analyzed: {total}

Severity Distribution:
βββ π’ Healthy: {healthy} ({healthy/total*100:.1f}%)
βββ π‘ Mild: {mild} ({mild/total*100:.1f}%)
βββ π Moderate: {moderate} ({moderate/total*100:.1f}%)
βββ π΄ Severe: {severe} ({severe/total*100:.1f}%)

Average Affected Area: {avg_affected:.1f}%

π¦ DISEASE FREQUENCY
"""
    for disease, count in sorted(disease_counts.items(), key=lambda x: -x[1]):
        report += f" β’ {disease}: {count} ({count/total*100:.1f}%)\n"

    # Highlight cases that need immediate attention (at most five listed).
    urgent = [r for r in results if r.urgency in ['high', 'critical']]
    if urgent:
        report += f"""
β οΈ URGENT ATTENTION REQUIRED
{len(urgent)} images require immediate attention:
"""
        for r in urgent[:5]:  # Show top 5
            report += f" β’ {Path(r.image_path).name}: {r.disease_name} ({r.urgency})\n"

    report += """
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
"""
    return report
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
def quick_diagnose(
    image_path: str,
    use_mock: bool = True
) -> DiagnosticResult:
    """
    One-call convenience wrapper: build a pipeline and diagnose one image.

    Args:
        image_path: Path to image
        use_mock: Use mock models (for testing without SAM 3)

    Returns:
        DiagnosticResult for the image
    """
    throwaway_pipeline = CropDoctorPipeline(
        use_mock_sam3=use_mock,
        use_llm=False
    )
    return throwaway_pipeline.diagnose(image_path)
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
if __name__ == "__main__":
    # Smoke-test the pipeline end to end using the mock segmenter
    # (no SAM 3 checkpoint or API key required).
    import os
    import tempfile

    print("Testing CropDoctor Pipeline...")

    # Create a synthetic brown test image in a portable temp location
    # (the previous hard-coded /tmp path does not exist on Windows).
    test_img = Image.new("RGB", (640, 480), color=(139, 69, 19))
    fd, test_path = tempfile.mkstemp(suffix=".jpg")
    os.close(fd)
    test_img.save(test_path)

    try:
        # Run pipeline
        pipeline = CropDoctorPipeline(use_mock_sam3=True, use_llm=False)
        result = pipeline.diagnose(test_path)
    finally:
        # Clean up the temp file even if diagnosis fails.
        os.remove(test_path)

    print(f"\nπ Diagnosis Results:")
    print(f"  Disease: {result.disease_name}")
    print(f"  Type: {result.disease_type}")
    print(f"  Severity: {result.severity_label} (Level {result.severity_level})")
    print(f"  Affected Area: {result.affected_area_percent:.1f}%")
    print(f"  Urgency: {result.urgency}")
    print(f"\nπΏ Recommended Treatments:")
    for t in result.organic_treatments[:3]:
        print(f"  β’ {t}")
|
src/sam3_segmentation.py
ADDED
|
@@ -0,0 +1,864 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SAM 3 Segmentation Module for CropDoctor-Semantic
|
| 3 |
+
==================================================
|
| 4 |
+
|
| 5 |
+
This module provides the core segmentation functionality using Meta's SAM 3
|
| 6 |
+
for concept-based plant disease detection.
|
| 7 |
+
|
| 8 |
+
SAM 3 enables zero-shot segmentation using natural language prompts,
|
| 9 |
+
allowing detection of disease symptoms without task-specific training.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import numpy as np
|
| 14 |
+
from PIL import Image
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from typing import List, Dict, Tuple, Optional, Union
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
import yaml
|
| 19 |
+
import logging
|
| 20 |
+
|
| 21 |
+
# Configure logging
|
| 22 |
+
logging.basicConfig(level=logging.INFO)
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@dataclass
class SegmentationResult:
    """Container for the output of one text-prompted segmentation pass.

    All arrays are aligned on their first axis: entry ``i`` of ``masks``,
    ``boxes``, ``scores`` and ``prompt_indices`` describes the same detection.
    """
    masks: np.ndarray  # Shape: (N, H, W) boolean masks
    boxes: np.ndarray  # Shape: (N, 4) bounding boxes [x1, y1, x2, y2]
    scores: np.ndarray  # Shape: (N,) confidence scores
    prompts: List[str]  # Full list of prompts the pass was run with
    prompt_indices: np.ndarray  # Shape: (N,) index into ``prompts`` per mask
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class SAM3Segmenter:
|
| 37 |
+
"""
|
| 38 |
+
SAM 3 based segmentation for plant disease detection.
|
| 39 |
+
|
| 40 |
+
Uses text prompts to detect and segment disease symptoms in plant images.
|
| 41 |
+
SAM 3's Promptable Concept Segmentation (PCS) enables open-vocabulary
|
| 42 |
+
detection without fine-tuning.
|
| 43 |
+
|
| 44 |
+
Example:
|
| 45 |
+
>>> segmenter = SAM3Segmenter("models/sam3/sam3.pt")
|
| 46 |
+
>>> result = segmenter.segment_with_concepts(
|
| 47 |
+
... "leaf_image.jpg",
|
| 48 |
+
... ["leaf with brown spots", "healthy leaf"]
|
| 49 |
+
... )
|
| 50 |
+
>>> print(f"Found {len(result.masks)} regions")
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
def __init__(
    self,
    checkpoint_path: str = "models/sam3/sam3.pt",
    config_path: str = "configs/sam3_config.yaml",
    device: Optional[str] = None
):
    """
    Initialize SAM 3 segmenter.

    Args:
        checkpoint_path: Path to SAM 3 checkpoint
        config_path: Path to configuration file
        device: Device to use (cuda, cpu, mps). Auto-detected if None.
    """
    self.checkpoint_path = Path(checkpoint_path)
    self.config = self._load_config(config_path)

    # Pick the best available accelerator when the caller did not choose.
    if device is not None:
        self.device = device
    elif torch.cuda.is_available():
        self.device = "cuda"
    elif torch.backends.mps.is_available():
        self.device = "mps"
    else:
        self.device = "cpu"

    logger.info(f"Using device: {self.device}")

    # Model/processor are loaded lazily by load_model().
    self.model = None
    self.processor = None
|
| 86 |
+
|
| 87 |
+
def _load_config(self, config_path: str) -> dict:
    """Read the YAML config from disk, or fall back to built-in defaults."""
    path = Path(config_path)
    if not path.exists():
        logger.warning(f"Config not found at {path}, using defaults")
        return self._default_config()
    with open(path, 'r') as fh:
        return yaml.safe_load(fh)
|
| 96 |
+
|
| 97 |
+
def _default_config(self) -> dict:
|
| 98 |
+
"""Return default configuration."""
|
| 99 |
+
return {
|
| 100 |
+
"inference": {
|
| 101 |
+
"confidence_threshold": 0.25,
|
| 102 |
+
"presence_threshold": 0.5,
|
| 103 |
+
"max_objects_per_prompt": 50,
|
| 104 |
+
"min_mask_area": 100
|
| 105 |
+
},
|
| 106 |
+
"visualization": {
|
| 107 |
+
"mask_alpha": 0.5,
|
| 108 |
+
"show_confidence": True
|
| 109 |
+
}
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
def load_model(self):
    """Load the SAM 3 model and its processor (no-op when already loaded).

    Raises:
        ImportError: if the ``sam3`` package is not installed.
        FileNotFoundError: if the checkpoint file is missing.
    """
    if self.model is not None:
        return  # already loaded

    logger.info("Loading SAM 3 model...")

    try:
        # Import SAM 3 modules lazily so the rest of the package works
        # without the dependency installed.
        from sam3.model_builder import build_sam3_image_model
        from sam3.model.sam3_image_processor import Sam3Processor

        # Build model and move it to the selected device.
        self.model = build_sam3_image_model(checkpoint=str(self.checkpoint_path))
        self.model.to(self.device)

        # Optionally run in fp16 on CUDA (config "model.half_precision").
        wants_half = self.config.get("model", {}).get("half_precision", True)
        if wants_half and self.device == "cuda":
            self.model = self.model.half()

        self.model.eval()

        # Processor wraps the model for prompt-based inference.
        self.processor = Sam3Processor(self.model)

        logger.info("SAM 3 model loaded successfully")

    except ImportError:
        logger.error("SAM 3 not installed. Please install from: https://github.com/facebookresearch/sam3")
        raise
    except FileNotFoundError:
        logger.error(f"Checkpoint not found at {self.checkpoint_path}")
        raise
|
| 144 |
+
|
| 145 |
+
def segment_with_concepts(
    self,
    image: Union[str, Path, Image.Image, np.ndarray],
    text_prompts: List[str],
    confidence_threshold: Optional[float] = None
) -> SegmentationResult:
    """
    Segment image using text prompts.

    Runs SAM 3 once per prompt against a single cached image embedding and
    concatenates all detections that clear the confidence threshold.

    Args:
        image: Input image (path, PIL Image, or numpy array)
        text_prompts: List of text prompts describing concepts to detect
        confidence_threshold: Override default confidence threshold
            NOTE(review): a threshold of 0/0.0 is falsy and silently falls
            back to the config default -- confirm whether that is intended.

    Returns:
        SegmentationResult containing masks, boxes, scores, and prompt info
        (empty, correctly-shaped arrays when nothing clears the threshold)
    """
    # Ensure model is loaded (no-op after the first call)
    self.load_model()

    # Load image if needed, normalizing to a PIL RGB image
    if isinstance(image, (str, Path)):
        image = Image.open(image).convert("RGB")
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    # Get threshold (explicit override wins over config default)
    threshold = confidence_threshold or self.config["inference"]["confidence_threshold"]

    # Set image in processor; the embedding is computed once and reused
    # across every prompt below.
    inference_state = self.processor.set_image(image)

    # Collect results from all prompts
    all_masks = []
    all_boxes = []
    all_scores = []
    all_prompt_indices = []

    for prompt_idx, prompt in enumerate(text_prompts):
        logger.debug(f"Processing prompt: {prompt}")

        # Get segmentation for this prompt
        output = self.processor.set_text_prompt(
            state=inference_state,
            prompt=prompt
        )

        masks = output["masks"]
        boxes = output["boxes"]
        scores = output["scores"]

        if masks is not None and len(masks) > 0:
            # Convert to numpy (processor outputs may be torch tensors)
            masks_np = masks.cpu().numpy() if torch.is_tensor(masks) else masks
            boxes_np = boxes.cpu().numpy() if torch.is_tensor(boxes) else boxes
            scores_np = scores.cpu().numpy() if torch.is_tensor(scores) else scores

            # Filter by confidence (boolean keep-mask over detections)
            mask = scores_np >= threshold

            if mask.any():
                all_masks.append(masks_np[mask])
                all_boxes.append(boxes_np[mask])
                all_scores.append(scores_np[mask])
                # Record which prompt produced each surviving detection
                all_prompt_indices.append(
                    np.full(mask.sum(), prompt_idx, dtype=np.int32)
                )

    # Combine results across prompts into flat, first-axis-aligned arrays
    if all_masks:
        combined_masks = np.concatenate(all_masks, axis=0)
        combined_boxes = np.concatenate(all_boxes, axis=0)
        combined_scores = np.concatenate(all_scores, axis=0)
        combined_indices = np.concatenate(all_prompt_indices, axis=0)
    else:
        # Return empty results shaped to match the input image
        h, w = np.array(image).shape[:2]
        combined_masks = np.zeros((0, h, w), dtype=bool)
        combined_boxes = np.zeros((0, 4), dtype=np.float32)
        combined_scores = np.zeros((0,), dtype=np.float32)
        combined_indices = np.zeros((0,), dtype=np.int32)

    return SegmentationResult(
        masks=combined_masks,
        boxes=combined_boxes,
        scores=combined_scores,
        prompts=text_prompts,
        prompt_indices=combined_indices
    )
|
| 234 |
+
|
| 235 |
+
def segment_disease_regions(
    self,
    image: Union[str, Path, Image.Image, np.ndarray],
    profile: str = "standard"
) -> SegmentationResult:
    """
    Segment disease regions using a predefined prompt profile.

    Looks up the named profile in the ``analysis_profiles`` section of the
    loaded configuration and delegates to :meth:`segment_with_concepts`
    with that profile's prompt list.

    Args:
        image: Input image (path, PIL image, or numpy array)
        profile: Analysis profile name ("quick_scan", "standard",
            "comprehensive", "pest_focused")

    Returns:
        SegmentationResult for the specified analysis profile

    Raises:
        ValueError: If the profile name is not present in the configuration.
    """
    known_profiles = self.config.get("analysis_profiles", {})

    # Validate the profile name before touching the model.
    if profile not in known_profiles:
        available = list(known_profiles.keys())
        raise ValueError(f"Profile '{profile}' not found. Available: {available}")

    prompts = known_profiles[profile]["prompts"]
    logger.info(f"Using profile '{profile}' with {len(prompts)} prompts")

    return self.segment_with_concepts(image, prompts)
|
| 260 |
+
|
| 261 |
+
def calculate_affected_area(
    self,
    result: SegmentationResult,
    healthy_prompt_idx: Optional[int] = None
) -> Dict[str, float]:
    """
    Calculate the percentage of affected (diseased) area.

    Per-symptom percentages are computed from the union of all masks
    assigned to each prompt; the total affected percentage is computed
    from the union of all non-healthy masks, so overlapping detections
    are not double-counted.

    Args:
        result: Segmentation result
        healthy_prompt_idx: Index of the "healthy" prompt, whose masks are
            excluded from the affected total and reported separately.

    Returns:
        Dictionary with keys ``total_affected_percent``, ``per_symptom``
        and (when a healthy index was given) ``healthy_percent``.
    """
    # No detections at all: nothing is affected.
    if len(result.masks) == 0:
        return {"total_affected_percent": 0.0, "per_symptom": {}}

    # Total image area from the first mask's shape (all masks share it).
    h, w = result.masks[0].shape
    total_area = h * w

    # Per-prompt union area.
    per_symptom = {}
    healthy_area = 0

    for prompt_idx, prompt in enumerate(result.prompts):
        mask_indices = result.prompt_indices == prompt_idx
        if mask_indices.any():
            combined_mask = result.masks[mask_indices].any(axis=0)
            area = combined_mask.sum()
            per_symptom[prompt] = (area / total_area) * 100
            if healthy_prompt_idx is not None and prompt_idx == healthy_prompt_idx:
                healthy_area = area

    # Union of every non-healthy mask, so overlaps count once.
    all_diseased_mask = np.zeros((h, w), dtype=bool)
    for prompt_idx in range(len(result.prompts)):
        if healthy_prompt_idx is None or prompt_idx != healthy_prompt_idx:
            mask_indices = result.prompt_indices == prompt_idx
            if mask_indices.any():
                all_diseased_mask |= result.masks[mask_indices].any(axis=0)

    affected_percent = (all_diseased_mask.sum() / total_area) * 100

    return {
        "total_affected_percent": affected_percent,
        "per_symptom": per_symptom,
        # BUG FIX: the original used truthiness (`if healthy_prompt_idx`),
        # which wrongly reported None whenever the healthy prompt was at
        # index 0. An explicit None check honors index 0.
        "healthy_percent": (healthy_area / total_area) * 100
        if healthy_prompt_idx is not None else None
    }
|
| 316 |
+
|
| 317 |
+
def get_disease_prompts(self, category: str = "all") -> List[str]:
    """
    Get predefined disease detection prompts from the configuration.

    Args:
        category: Prompt category ("general", "fungal", "bacterial",
            "viral", "nutrient", "pest", or "all")

    Returns:
        List of prompts for the specified category; for "all", the
        concatenation of every category in configuration order.

    Raises:
        ValueError: If the category is unknown.
    """
    catalog = self.config.get("prompts", {})

    if category == "all":
        # Flatten every category into one list, preserving order.
        return [prompt for group in catalog.values() for prompt in group]

    if category in catalog:
        return catalog[category]

    available = list(catalog.keys()) + ["all"]
    raise ValueError(f"Category '{category}' not found. Available: {available}")
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
class MockSAM3Segmenter(SAM3Segmenter):
    """
    Color-based segmentation for plant disease detection.

    Model-free fallback: analyzes actual image colors to detect disease
    symptoms instead of running SAM 3:
    - Green regions = healthy tissue
    - Brown/yellow/spotted regions = potential disease

    Uses scipy.ndimage for blob detection on non-green regions.
    """
|
| 352 |
+
|
| 353 |
+
def load_model(self):
    """Skip model loading for color-based analysis.

    Assigns sentinel strings to ``self.model`` / ``self.processor`` so
    they are non-None — presumably so "already loaded" checks elsewhere
    treat this segmenter as initialized (TODO confirm against base class).
    """
    logger.info("Using MockSAM3Segmenter (color-based analysis)")
    self.model = "color_analysis"
    self.processor = "color_analysis"
|
| 358 |
+
|
| 359 |
+
def _compute_hsv(self, img_array: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
| 360 |
+
"""Convert RGB image to HSV channels."""
|
| 361 |
+
r, g, b = img_array[:,:,0], img_array[:,:,1], img_array[:,:,2]
|
| 362 |
+
|
| 363 |
+
rgb_max = np.maximum(np.maximum(r, g), b)
|
| 364 |
+
rgb_min = np.minimum(np.minimum(r, g), b)
|
| 365 |
+
delta = (rgb_max - rgb_min).astype(np.float32) + 1e-10
|
| 366 |
+
|
| 367 |
+
# Value
|
| 368 |
+
v = rgb_max
|
| 369 |
+
|
| 370 |
+
# Saturation
|
| 371 |
+
s = np.where(rgb_max > 0, (delta / (rgb_max.astype(np.float32) + 1e-10)) * 255, 0).astype(np.uint8)
|
| 372 |
+
|
| 373 |
+
# Hue
|
| 374 |
+
h_channel = np.zeros_like(r, dtype=np.float32)
|
| 375 |
+
|
| 376 |
+
mask_r = (rgb_max == r)
|
| 377 |
+
h_channel[mask_r] = 60 * (((g[mask_r].astype(np.float32) - b[mask_r]) / delta[mask_r]) % 6)
|
| 378 |
+
|
| 379 |
+
mask_g = (rgb_max == g) & ~mask_r
|
| 380 |
+
h_channel[mask_g] = 60 * (((b[mask_g].astype(np.float32) - r[mask_g]) / delta[mask_g]) + 2)
|
| 381 |
+
|
| 382 |
+
mask_b = (rgb_max == b) & ~mask_r & ~mask_g
|
| 383 |
+
h_channel[mask_b] = 60 * (((r[mask_b].astype(np.float32) - g[mask_b]) / delta[mask_b]) + 4)
|
| 384 |
+
|
| 385 |
+
h_channel = (h_channel / 2).astype(np.uint8) # 0-180 range
|
| 386 |
+
|
| 387 |
+
return h_channel, s, v
|
| 388 |
+
|
| 389 |
+
def _segment_leaf(self, img_array: np.ndarray, h: np.ndarray, s: np.ndarray, v: np.ndarray) -> np.ndarray:
    """
    Segment the leaf/plant tissue from the background.

    Uses color analysis to find plant material:
    - High saturation (plants are colorful, backgrounds are often gray/neutral)
    - Green to yellow-brown hue range (plant tissue colors)
    - Reasonable brightness

    Returns the largest connected region as the leaf mask, or a full-image
    mask when no plausible leaf is found.

    Args:
        img_array: RGB image as a numpy array (only its shape is used here).
        h, s, v: HSV channels as produced by ``_compute_hsv`` (hue 0-180).

    Returns:
        Boolean mask of shape (H, W), True on leaf pixels.
    """
    from scipy import ndimage

    img_h, img_w = img_array.shape[:2]

    # Plant tissue typically has:
    # 1. Good saturation (colorful, not gray)
    # 2. Hue in plant range: green (35-85) OR yellow/brown diseased (10-45)
    # 3. Reasonable brightness

    # Broad plant color range (green to brown/yellow).
    plant_hue_mask = (
        ((h >= 15) & (h <= 90)) |  # Green to yellow-green
        ((h >= 5) & (h <= 30))  # Brown/orange (diseased tissue)
    )

    # Plant tissue has good saturation and brightness.
    plant_saturation_mask = (s >= 25)  # Saturated (not gray)
    plant_brightness_mask = (v >= 30) & (v <= 250)  # Not too dark, not blown out

    # Combine criteria.
    potential_leaf = plant_hue_mask & plant_saturation_mask & plant_brightness_mask

    # Also include high saturation areas regardless of hue (catches more plant tissue).
    high_saturation = (s >= 50) & plant_brightness_mask
    potential_leaf = potential_leaf | high_saturation

    # Clean up with morphological operations: closing bridges small gaps,
    # opening removes speckle, fill_holes solidifies the leaf interior.
    potential_leaf = ndimage.binary_closing(potential_leaf, iterations=3)
    potential_leaf = ndimage.binary_opening(potential_leaf, iterations=2)
    potential_leaf = ndimage.binary_fill_holes(potential_leaf)

    # Find the largest connected component (main leaf).
    labeled, num_features = ndimage.label(potential_leaf)

    if num_features == 0:
        # No leaf found - return full image as fallback so downstream
        # analysis still runs on something.
        logger.warning("No leaf detected, using full image")
        return np.ones((img_h, img_w), dtype=bool)

    # Find largest component (labels are 1-based).
    component_sizes = ndimage.sum(potential_leaf, labeled, range(1, num_features + 1))
    largest_idx = np.argmax(component_sizes) + 1
    leaf_mask = (labeled == largest_idx)

    # Leaf should cover at least 10% of image to be considered valid.
    leaf_coverage = leaf_mask.sum() / (img_h * img_w)
    if leaf_coverage < 0.10:
        logger.warning(f"Leaf too small ({leaf_coverage:.1%}), using full image")
        return np.ones((img_h, img_w), dtype=bool)

    logger.debug(f"Leaf segmented: {leaf_coverage:.1%} of image")
    return leaf_mask
|
| 452 |
+
|
| 453 |
+
def _detect_disease_regions(self, image: Image.Image) -> Tuple[np.ndarray, List[Dict]]:
    """
    Detect disease regions based on color analysis.

    First segments the leaf from background, then analyzes only
    the leaf area for disease symptoms (brown, yellow, dark-necrotic
    and white/powdery pixels). Connected abnormal regions are turned
    into blobs with a heuristic confidence score.

    Args:
        image: RGB PIL image.

    Returns:
        Tuple of (binary mask of all abnormal regions, list of blob info
        dicts with keys 'mask', 'bbox' [x_min, y_min, x_max, y_max],
        'area', 'confidence').
    """
    from scipy import ndimage

    img_array = np.array(image)
    img_h, img_w = img_array.shape[:2]

    # Compute HSV (hue is in the 0-180 range, see _compute_hsv).
    h_channel, s, v = self._compute_hsv(img_array)

    # Step 1: Segment the leaf from background.
    leaf_mask = self._segment_leaf(img_array, h_channel, s, v)
    leaf_area = leaf_mask.sum()

    if leaf_area == 0:
        # Degenerate case: nothing to analyze.
        return np.zeros((img_h, img_w), dtype=bool), []

    # Step 2: Within the leaf, find healthy green regions.
    green_mask = (
        (h_channel >= 35) & (h_channel <= 85) &  # Green hue
        (s >= 30) &  # Saturated
        (v >= 30) &  # Not too dark
        leaf_mask  # Only within leaf
    )

    green_area = green_mask.sum()
    green_ratio = green_area / leaf_area

    logger.debug(f"Within leaf - Green: {green_ratio:.1%}, Leaf area: {leaf_area}px")

    # Step 3: Define disease colors (only within leaf!).
    # Brown spots: low hue, moderate saturation.
    brown_mask = (
        (h_channel >= 5) & (h_channel <= 25) &
        (s >= 30) &
        (v >= 40) & (v <= 200) &
        leaf_mask
    )

    # Yellow/chlorosis: yellow hue, high saturation.
    yellow_mask = (
        (h_channel >= 20) & (h_channel <= 40) &
        (s >= 40) &
        (v >= 80) &
        leaf_mask
    )

    # Necrotic dark spots (within leaf only).
    dark_spots = (
        (v <= 60) &
        (s >= 15) &  # Some color, not pure black
        leaf_mask
    )

    # White spots (powdery mildew) - within leaf.
    white_spots = (
        (v >= 200) &
        (s <= 40) &
        leaf_mask
    )

    # Combine abnormal regions.
    abnormal_mask = (brown_mask | yellow_mask | dark_spots | white_spots)
    abnormal_area = abnormal_mask.sum()

    logger.debug(f"Abnormal pixels within leaf: {abnormal_area} ({abnormal_area/leaf_area:.1%} of leaf)")

    # If mostly green (>80% of leaf is green) with few abnormal pixels,
    # consider the leaf healthy and return no detections.
    if green_ratio > 0.80 and abnormal_area < leaf_area * 0.05:
        logger.info(f"Leaf appears healthy ({green_ratio:.0%} green)")
        return np.zeros((img_h, img_w), dtype=bool), []

    # If very little abnormal tissue (<2% of leaf), also healthy.
    if abnormal_area < leaf_area * 0.02:
        logger.info("Minimal abnormal tissue detected - healthy")
        return np.zeros((img_h, img_w), dtype=bool), []

    # Clean up the abnormal mask: opening removes speckle, closing
    # merges fragments of the same lesion.
    abnormal_mask = ndimage.binary_opening(abnormal_mask, iterations=1)
    abnormal_mask = ndimage.binary_closing(abnormal_mask, iterations=2)

    # Label connected components.
    labeled_array, num_features = ndimage.label(abnormal_mask)

    # Filter blobs by size (relative to leaf, not image).
    min_blob_area = max(50, leaf_area * 0.005)  # At least 0.5% of leaf
    max_blob_area = leaf_area * 0.6  # At most 60% of leaf

    blobs = []
    for label_idx in range(1, num_features + 1):
        blob_mask = (labeled_array == label_idx)
        blob_area = blob_mask.sum()

        if min_blob_area <= blob_area <= max_blob_area:
            # Get bounding box from the blob's row/column extents.
            rows = np.any(blob_mask, axis=1)
            cols = np.any(blob_mask, axis=0)
            y_min, y_max = np.where(rows)[0][[0, -1]]
            x_min, x_max = np.where(cols)[0][[0, -1]]

            # Calculate confidence based on the blob's average color.
            blob_region = img_array[blob_mask]
            avg_color = blob_region.mean(axis=0)

            r_ratio = avg_color[0] / 255
            g_ratio = avg_color[1] / 255
            b_ratio = avg_color[2] / 255

            # Score: more brown/yellow (red-heavy, low green/blue) = higher confidence.
            color_score = r_ratio - 0.5 * g_ratio + 0.3 * (1 - b_ratio)
            color_score = np.clip(color_score, 0, 1)

            # Area score relative to leaf (saturates at 10% of the leaf).
            area_ratio = blob_area / leaf_area
            area_score = min(1.0, area_ratio * 10)

            # Blend into a 0.3-0.95 confidence range.
            confidence = 0.4 + 0.4 * color_score + 0.2 * area_score
            confidence = np.clip(confidence, 0.3, 0.95)

            blobs.append({
                'mask': blob_mask,
                'bbox': [x_min, y_min, x_max, y_max],
                'area': blob_area,
                'confidence': float(confidence)
            })

    return abnormal_mask, blobs
|
| 588 |
+
|
| 589 |
+
def segment_with_concepts(
    self,
    image: Union[str, Path, Image.Image, np.ndarray],
    text_prompts: List[str],
    confidence_threshold: Optional[float] = None
) -> SegmentationResult:
    """
    Segment disease regions based on color analysis.

    Analyzes the image colors to detect abnormal (non-green) regions
    that may indicate disease. Returns empty results for healthy images.
    All detections are attributed to the first non-"healthy" prompt,
    since color analysis cannot distinguish symptom categories.

    Args:
        image: Input image (path, PIL image, or numpy array)
        text_prompts: Prompt list echoed back in the result
        confidence_threshold: Minimum blob confidence; when None, the
            config value ``inference.confidence_threshold`` is used.

    Returns:
        SegmentationResult with one mask/box/score per detected blob.
    """
    # Normalize input into a PIL image.
    if isinstance(image, (str, Path)):
        image = Image.open(image).convert("RGB")
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    w, h = image.size

    # BUG FIX: the original used `confidence_threshold or config[...]`,
    # which silently replaced an explicit threshold of 0.0 with the
    # config default. Use an explicit None check instead.
    if confidence_threshold is not None:
        threshold = confidence_threshold
    else:
        threshold = self.config["inference"]["confidence_threshold"]

    # Detect disease regions based on color.
    abnormal_mask, blobs = self._detect_disease_regions(image)

    # Filter by confidence threshold.
    blobs = [b for b in blobs if b['confidence'] >= threshold]

    if not blobs:
        logger.info("No disease regions detected (healthy image)")
        return SegmentationResult(
            masks=np.zeros((0, h, w), dtype=bool),
            boxes=np.zeros((0, 4), dtype=np.float32),
            scores=np.zeros((0,), dtype=np.float32),
            prompts=text_prompts,
            prompt_indices=np.zeros((0,), dtype=np.int32)
        )

    # Convert blobs to dense arrays.
    num_detections = len(blobs)
    masks = np.zeros((num_detections, h, w), dtype=bool)
    boxes = np.zeros((num_detections, 4), dtype=np.float32)
    scores = np.zeros(num_detections, dtype=np.float32)

    # Assign detections to the first disease-related prompt
    # (skip prompts containing "healthy").
    disease_prompt_idx = 0
    for idx, prompt in enumerate(text_prompts):
        if "healthy" not in prompt.lower():
            disease_prompt_idx = idx
            break

    prompt_indices = np.full(num_detections, disease_prompt_idx, dtype=np.int32)

    for i, blob in enumerate(blobs):
        masks[i] = blob['mask']
        boxes[i] = blob['bbox']
        scores[i] = blob['confidence']

    logger.info(f"Detected {num_detections} disease region(s)")

    return SegmentationResult(
        masks=masks,
        boxes=boxes,
        scores=scores,
        prompts=text_prompts,
        prompt_indices=prompt_indices
    )
|
| 655 |
+
|
| 656 |
+
|
| 657 |
+
class RFDETRSegmenter(SAM3Segmenter):
    """
    RF-DETR based object detection for plant disease detection.

    Uses a trained RF-DETR model (DETR-based detector) instead of SAM 3.
    RF-DETR is trained on annotated plant disease datasets with bounding
    boxes, so masks returned by this segmenter are box-shaped rectangles.

    Example:
        >>> segmenter = RFDETRSegmenter("models/rfdetr/best.pt")
        >>> result = segmenter.segment_with_concepts(image, ["disease"])
        >>> print(f"Found {len(result.masks)} disease regions")
    """
|
| 669 |
+
|
| 670 |
+
def __init__(
    self,
    checkpoint_path: str = "models/rfdetr/best.pt",
    config_path: str = "configs/sam3_config.yaml",
    device: Optional[str] = None,
    model_size: str = "medium"
):
    """
    Initialize RF-DETR segmenter.

    Args:
        checkpoint_path: Path to trained RF-DETR checkpoint
        config_path: Path to configuration file
        device: Device to use (auto-detected if None)
        model_size: RF-DETR model size (nano, small, medium, base)
    """
    # Base-class constructor stores paths/device and loads the config.
    super().__init__(checkpoint_path, config_path, device)
    self.model_size = model_size
    self.class_names = ["Pestalotiopsis"]  # Default class, updated after loading
|
| 689 |
+
|
| 690 |
+
def load_model(self):
    """Load the RF-DETR model lazily (no-op if already loaded).

    Selects the model class matching ``self.model_size`` and loads the
    custom checkpoint when it exists, otherwise falls back to the
    library's pretrained weights.

    Raises:
        ImportError: If the ``rfdetr`` package is not installed.
    """
    # Idempotent: skip if a model was already constructed.
    if self.model is not None:
        return

    logger.info(f"Loading RF-DETR {self.model_size} model...")

    try:
        # Import RF-DETR class matching the configured size;
        # any unrecognized size falls back to the base model.
        if self.model_size == "nano":
            from rfdetr import RFDETRNano as RFDETRModel
        elif self.model_size == "small":
            from rfdetr import RFDETRSmall as RFDETRModel
        elif self.model_size == "medium":
            from rfdetr import RFDETRMedium as RFDETRModel
        else:
            from rfdetr import RFDETRBase as RFDETRModel

        # Load model with custom weights if available.
        checkpoint = Path(self.checkpoint_path)
        if checkpoint.exists():
            logger.info(f"Loading custom weights from {checkpoint}")
            self.model = RFDETRModel(pretrain_weights=str(checkpoint))
        else:
            logger.warning(f"Checkpoint not found at {checkpoint}, using pretrained weights")
            self.model = RFDETRModel()

        logger.info("RF-DETR model loaded successfully")

    except ImportError as e:
        # Surface a clear remediation hint before propagating.
        logger.error(f"RF-DETR not installed: {e}")
        logger.error("Install with: pip install rfdetr")
        raise
|
| 723 |
+
|
| 724 |
+
def segment_with_concepts(
    self,
    image: Union[str, Path, Image.Image, np.ndarray],
    text_prompts: List[str],
    confidence_threshold: Optional[float] = None
) -> SegmentationResult:
    """
    Detect disease regions using RF-DETR.

    Note: RF-DETR is class-based (not prompt-based), so text_prompts
    are ignored for detection; they are only echoed back in the result,
    with every detection attributed to the first non-"healthy" prompt.

    Args:
        image: Input image (path, PIL image, or numpy array)
        text_prompts: Ignored (RF-DETR uses trained classes)
        confidence_threshold: Detection confidence threshold; when None,
            the config value ``inference.confidence_threshold`` is used.

    Returns:
        SegmentationResult with detected disease regions. Masks are
        box-shaped rectangles, since RF-DETR predicts bounding boxes.
    """
    self.load_model()

    # Normalize input into a PIL image.
    if isinstance(image, (str, Path)):
        pil_image = Image.open(image).convert("RGB")
    elif isinstance(image, np.ndarray):
        pil_image = Image.fromarray(image)
    else:
        pil_image = image

    w, h = pil_image.size

    # BUG FIX: the original used `confidence_threshold or config[...]`,
    # which silently replaced an explicit threshold of 0.0 with the
    # config default. Use an explicit None check instead.
    if confidence_threshold is not None:
        threshold = confidence_threshold
    else:
        threshold = self.config["inference"]["confidence_threshold"]

    def _empty_result() -> SegmentationResult:
        # Shared constructor for the "no detections" / failure cases.
        return SegmentationResult(
            masks=np.zeros((0, h, w), dtype=bool),
            boxes=np.zeros((0, 4), dtype=np.float32),
            scores=np.zeros((0,), dtype=np.float32),
            prompts=text_prompts,
            prompt_indices=np.zeros((0,), dtype=np.int32)
        )

    # Run RF-DETR detection; degrade to an empty result on failure.
    try:
        detections = self.model.predict(pil_image, threshold=threshold)
    except Exception as e:
        logger.error(f"RF-DETR prediction failed: {e}")
        return _empty_result()

    # Extract detections from the supervision Detections object.
    num_detections = len(detections)
    if num_detections == 0:
        logger.info("No disease regions detected")
        return _empty_result()

    # Get bounding boxes [x1, y1, x2, y2] and per-box confidence.
    boxes = detections.xyxy.astype(np.float32)
    scores = detections.confidence.astype(np.float32)

    # Create masks from bounding boxes (RF-DETR gives boxes, not masks),
    # clipping each box to the image bounds.
    masks = np.zeros((num_detections, h, w), dtype=bool)
    for i, box in enumerate(boxes):
        x1, y1, x2, y2 = box.astype(int)
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(w, x2), min(h, y2)
        masks[i, y1:y2, x1:x2] = True

    # Assign every detection to the first non-"healthy" prompt.
    disease_prompt_idx = 0
    for idx, prompt in enumerate(text_prompts):
        if "healthy" not in prompt.lower():
            disease_prompt_idx = idx
            break

    prompt_indices = np.full(num_detections, disease_prompt_idx, dtype=np.int32)

    logger.info(f"RF-DETR detected {num_detections} disease region(s)")

    return SegmentationResult(
        masks=masks,
        boxes=boxes,
        scores=scores,
        prompts=text_prompts,
        prompt_indices=prompt_indices
    )
|
| 813 |
+
|
| 814 |
+
|
| 815 |
+
def create_segmenter(
    checkpoint_path: str = "models/sam3/sam3.pt",
    config_path: str = "configs/sam3_config.yaml",
    use_mock: bool = False,
    use_rfdetr: bool = False,
    rfdetr_checkpoint: str = "models/rfdetr/best.pt",
    rfdetr_model_size: str = "medium"
) -> SAM3Segmenter:
    """
    Factory for the appropriate segmenter implementation.

    Precedence: RF-DETR detector first, then the color-based mock,
    otherwise the real SAM 3 segmenter.

    Args:
        checkpoint_path: Path to SAM 3 checkpoint
        config_path: Path to configuration
        use_mock: If True, use color-based mock segmenter
        use_rfdetr: If True, use RF-DETR detector (takes precedence)
        rfdetr_checkpoint: Path to RF-DETR checkpoint
        rfdetr_model_size: RF-DETR model size (nano, small, medium, base)

    Returns:
        SAM3Segmenter, MockSAM3Segmenter, or RFDETRSegmenter instance
    """
    if use_rfdetr:
        return RFDETRSegmenter(
            checkpoint_path=rfdetr_checkpoint,
            config_path=config_path,
            model_size=rfdetr_model_size
        )
    if use_mock:
        return MockSAM3Segmenter(checkpoint_path, config_path)
    return SAM3Segmenter(checkpoint_path, config_path)
|
| 847 |
+
|
| 848 |
+
|
| 849 |
+
if __name__ == "__main__":
    # Quick smoke test using the color-based mock (no model weights needed).
    segmenter = create_segmenter(use_mock=True)

    # Create a uniformly green test image: should be classified healthy,
    # i.e. yield zero detections.
    test_image = Image.new("RGB", (640, 480), color=(34, 139, 34))  # Forest green

    prompts = ["diseased leaf", "brown spots", "healthy tissue"]
    result = segmenter.segment_with_concepts(test_image, prompts)

    print(f"Found {len(result.masks)} regions")
    print(f"Scores: {result.scores}")
    print(f"Prompts used: {[result.prompts[i] for i in result.prompt_indices]}")

    # Summarize affected area from the detection result.
    areas = segmenter.calculate_affected_area(result)
    print(f"Affected area: {areas['total_affected_percent']:.1f}%")
|
src/severity_classifier.py
ADDED
|
@@ -0,0 +1,590 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Severity Classifier Module for CropDoctor-Semantic
|
| 3 |
+
===================================================
|
| 4 |
+
|
| 5 |
+
This module provides a CNN-based classifier to assess the severity
|
| 6 |
+
of plant diseases from segmented regions.
|
| 7 |
+
|
| 8 |
+
Severity Levels:
|
| 9 |
+
0 - Healthy: No disease symptoms
|
| 10 |
+
1 - Mild: <10% affected area, early stage
|
| 11 |
+
2 - Moderate: 10-30% affected, established infection
|
| 12 |
+
3 - Severe: >30% affected, critical intervention needed
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
from torch.utils.data import Dataset, DataLoader
|
| 19 |
+
import torchvision.transforms as transforms
|
| 20 |
+
from torchvision import models
|
| 21 |
+
import numpy as np
|
| 22 |
+
from PIL import Image
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
from typing import Tuple, Dict, List, Optional, Union
|
| 25 |
+
from dataclasses import dataclass
|
| 26 |
+
import logging
|
| 27 |
+
|
| 28 |
+
logger = logging.getLogger(__name__)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@dataclass
class SeverityPrediction:
    """Container for severity classification results.

    Bundles the classifier's discrete prediction with its confidence and
    the mask-derived affected-area estimate.
    """
    severity_level: int  # 0-3, index into SEVERITY_LABELS
    severity_label: str  # "healthy", "mild", "moderate", "severe"
    confidence: float  # 0-1, confidence of the predicted class
    probabilities: Dict[str, float]  # Per-class probabilities keyed by label
    affected_area_percent: float  # From mask analysis, not from the CNN
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# Severity level mapping: integer class index -> short label.
SEVERITY_LABELS = {
    0: "healthy",
    1: "mild",
    2: "moderate",
    3: "severe"
}

# Human-readable description per severity level, suitable for reports/UI.
SEVERITY_DESCRIPTIONS = {
    0: "No disease symptoms detected. Plant appears healthy.",
    1: "Early stage infection. Less than 10% of tissue affected. Preventive action recommended.",
    2: "Established infection. 10-30% of tissue affected. Treatment required.",
    3: "Severe infection. Over 30% of tissue affected. Urgent intervention needed."
}
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class SeverityClassifierCNN(nn.Module):
    """
    CNN model for disease severity classification (4 classes by default).

    Architecture options (ImageNet-pretrained torchvision backbones with
    the classifier head replaced):
    - EfficientNet-B0 (lightweight, fast)
    - ResNet-50 (balanced)
    - ConvNeXt-Tiny (modern, accurate)
    """
|
| 66 |
+
|
| 67 |
+
def __init__(
|
| 68 |
+
self,
|
| 69 |
+
num_classes: int = 4,
|
| 70 |
+
backbone: str = "efficientnet_b0",
|
| 71 |
+
pretrained: bool = True,
|
| 72 |
+
dropout: float = 0.3
|
| 73 |
+
):
|
| 74 |
+
super().__init__()
|
| 75 |
+
|
| 76 |
+
self.num_classes = num_classes
|
| 77 |
+
self.backbone_name = backbone
|
| 78 |
+
|
| 79 |
+
# Load backbone
|
| 80 |
+
if backbone == "efficientnet_b0":
|
| 81 |
+
self.backbone = models.efficientnet_b0(
|
| 82 |
+
weights=models.EfficientNet_B0_Weights.DEFAULT if pretrained else None
|
| 83 |
+
)
|
| 84 |
+
in_features = self.backbone.classifier[1].in_features
|
| 85 |
+
self.backbone.classifier = nn.Sequential(
|
| 86 |
+
nn.Dropout(dropout),
|
| 87 |
+
nn.Linear(in_features, num_classes)
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
elif backbone == "resnet50":
|
| 91 |
+
self.backbone = models.resnet50(
|
| 92 |
+
weights=models.ResNet50_Weights.DEFAULT if pretrained else None
|
| 93 |
+
)
|
| 94 |
+
in_features = self.backbone.fc.in_features
|
| 95 |
+
self.backbone.fc = nn.Sequential(
|
| 96 |
+
nn.Dropout(dropout),
|
| 97 |
+
nn.Linear(in_features, num_classes)
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
elif backbone == "convnext_tiny":
|
| 101 |
+
self.backbone = models.convnext_tiny(
|
| 102 |
+
weights=models.ConvNeXt_Tiny_Weights.DEFAULT if pretrained else None
|
| 103 |
+
)
|
| 104 |
+
in_features = self.backbone.classifier[2].in_features
|
| 105 |
+
self.backbone.classifier = nn.Sequential(
|
| 106 |
+
nn.Flatten(1),
|
| 107 |
+
nn.LayerNorm(in_features),
|
| 108 |
+
nn.Dropout(dropout),
|
| 109 |
+
nn.Linear(in_features, num_classes)
|
| 110 |
+
)
|
| 111 |
+
else:
|
| 112 |
+
raise ValueError(f"Unknown backbone: {backbone}")
|
| 113 |
+
|
| 114 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 115 |
+
return self.backbone(x)
|
| 116 |
+
|
| 117 |
+
def predict(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 118 |
+
"""Get predictions with probabilities."""
|
| 119 |
+
logits = self.forward(x)
|
| 120 |
+
probs = F.softmax(logits, dim=1)
|
| 121 |
+
preds = torch.argmax(probs, dim=1)
|
| 122 |
+
return preds, probs
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class SeverityClassifier:
    """
    High-level interface for severity classification.

    Handles image preprocessing, model loading, and prediction formatting.

    Example:
        >>> classifier = SeverityClassifier("models/severity_classifier/best.pt")
        >>> result = classifier.classify("diseased_leaf.jpg")
        >>> print(f"Severity: {result.severity_label} ({result.confidence:.2f})")
    """

    def __init__(
        self,
        checkpoint_path: Optional[str] = None,
        device: Optional[str] = None,
        image_size: int = 224
    ):
        """
        Initialize severity classifier.

        Args:
            checkpoint_path: Path to trained model checkpoint
            device: Device to use (auto-detected if None)
            image_size: Input image size for the model
        """
        # Set device: prefer GPU when available, else CPU.
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        self.image_size = image_size
        self.checkpoint_path = checkpoint_path

        # Model is created lazily by load_model() on first use.
        self.model = None
        self._setup_transforms()

    def _setup_transforms(self):
        """Setup image preprocessing transforms."""
        # ImageNet normalization (the backbones are ImageNet-pretrained).
        self.transform = transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

        # Augmentation pipeline used only at training time (see
        # train_classifier); inference always uses self.transform.
        self.train_transform = transforms.Compose([
            transforms.RandomResizedCrop(self.image_size, scale=(0.8, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

    def load_model(self, backbone: str = "efficientnet_b0"):
        """Load or initialize the model (idempotent; no-op once loaded)."""
        if self.model is not None:
            return

        self.model = SeverityClassifierCNN(
            num_classes=4,
            backbone=backbone,
            pretrained=True
        )

        if self.checkpoint_path and Path(self.checkpoint_path).exists():
            logger.info(f"Loading checkpoint from {self.checkpoint_path}")
            # NOTE(review): weights_only=False unpickles arbitrary objects —
            # only load checkpoints from trusted sources.
            checkpoint = torch.load(self.checkpoint_path, map_location=self.device, weights_only=False)
            self.model.load_state_dict(checkpoint["model_state_dict"])
        else:
            # Falls back to ImageNet-pretrained backbone with an untrained head.
            logger.warning("No checkpoint loaded, using pretrained backbone only")

        self.model.to(self.device)
        self.model.eval()

    def preprocess_image(
        self,
        image: Union[str, Path, Image.Image, np.ndarray]
    ) -> torch.Tensor:
        """Preprocess image for classification.

        Accepts a file path, PIL image, or numpy array and returns a
        normalized float tensor of shape (1, 3, image_size, image_size).
        """
        if isinstance(image, (str, Path)):
            image = Image.open(image).convert("RGB")
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        tensor = self.transform(image)
        return tensor.unsqueeze(0)  # Add batch dimension

    def classify(
        self,
        image: Union[str, Path, Image.Image, np.ndarray],
        mask: Optional[np.ndarray] = None
    ) -> SeverityPrediction:
        """
        Classify disease severity in an image.

        Args:
            image: Input image (path, PIL Image, or numpy array)
            mask: Optional binary mask of diseased region

        Returns:
            SeverityPrediction with severity level, confidence, and details
        """
        self.load_model()

        # Calculate affected area from mask.
        # Assumes a binary {0,1}/bool mask — TODO confirm with callers.
        affected_percent = 0.0
        if mask is not None:
            affected_percent = (mask.sum() / mask.size) * 100

        # Preprocess and predict
        input_tensor = self.preprocess_image(image).to(self.device)

        with torch.no_grad():
            pred, probs = self.model.predict(input_tensor)

        severity_level = pred.item()
        confidence = probs[0, severity_level].item()

        # Format per-class probabilities keyed by label name.
        prob_dict = {
            SEVERITY_LABELS[i]: probs[0, i].item()
            for i in range(4)
        }

        return SeverityPrediction(
            severity_level=severity_level,
            severity_label=SEVERITY_LABELS[severity_level],
            confidence=confidence,
            probabilities=prob_dict,
            affected_area_percent=affected_percent
        )

    def classify_region(
        self,
        image: Union[str, Path, Image.Image, np.ndarray],
        mask: np.ndarray
    ) -> SeverityPrediction:
        """
        Classify severity of a specific masked region.

        Extracts the bounding box of the mask and classifies that region.

        Args:
            image: Full image
            mask: Binary mask of region to classify

        Returns:
            SeverityPrediction for the masked region
        """
        # Load image if needed
        if isinstance(image, (str, Path)):
            image = Image.open(image).convert("RGB")
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        img_array = np.array(image)

        # Get bounding box from mask: rows/cols that contain any True pixel.
        rows = np.any(mask, axis=1)
        cols = np.any(mask, axis=0)

        if not rows.any() or not cols.any():
            # Empty mask: short-circuit with a confident "healthy" result.
            return SeverityPrediction(
                severity_level=0,
                severity_label="healthy",
                confidence=1.0,
                probabilities={"healthy": 1.0, "mild": 0.0, "moderate": 0.0, "severe": 0.0},
                affected_area_percent=0.0
            )

        y_min, y_max = np.where(rows)[0][[0, -1]]
        x_min, x_max = np.where(cols)[0][[0, -1]]

        # Add padding around the box (clamped to image bounds) for context.
        pad = 10
        y_min = max(0, y_min - pad)
        y_max = min(img_array.shape[0], y_max + pad)
        x_min = max(0, x_min - pad)
        x_max = min(img_array.shape[1], x_max + pad)

        # Crop region and the matching mask window.
        cropped = img_array[y_min:y_max, x_min:x_max]
        cropped_mask = mask[y_min:y_max, x_min:x_max]

        return self.classify(cropped, mask=cropped_mask)

    def classify_batch(
        self,
        images: List[Union[str, Path, Image.Image, np.ndarray]],
        masks: Optional[List[np.ndarray]] = None,
        batch_size: int = 16
    ) -> List[SeverityPrediction]:
        """
        Classify multiple images in batches.

        Args:
            images: List of images to classify
            masks: Optional list of masks for each image
            batch_size: Batch size for inference

        Returns:
            List of SeverityPrediction for each image
        """
        self.load_model()

        results = []

        for i in range(0, len(images), batch_size):
            batch_images = images[i:i + batch_size]
            # Align masks with this slice; substitute Nones when absent.
            batch_masks = masks[i:i + batch_size] if masks else [None] * len(batch_images)

            # Preprocess batch into one (B, 3, H, W) tensor.
            tensors = [self.preprocess_image(img) for img in batch_images]
            batch_tensor = torch.cat(tensors, dim=0).to(self.device)

            # Predict
            with torch.no_grad():
                preds, probs = self.model.predict(batch_tensor)

            # Format results per image.
            for j, (pred, prob) in enumerate(zip(preds, probs)):
                mask = batch_masks[j]
                affected_percent = 0.0
                if mask is not None:
                    affected_percent = (mask.sum() / mask.size) * 100

                severity_level = pred.item()

                results.append(SeverityPrediction(
                    severity_level=severity_level,
                    severity_label=SEVERITY_LABELS[severity_level],
                    confidence=prob[severity_level].item(),
                    probabilities={
                        SEVERITY_LABELS[k]: prob[k].item()
                        for k in range(4)
                    },
                    affected_area_percent=affected_percent
                ))

        return results
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
class PlantDiseaseDataset(Dataset):
    """
    Dataset class for training severity classifier.

    Expected folder structure:
        data_root/
            healthy/
                image1.jpg
                image2.jpg
            mild/
                ...
            moderate/
                ...
            severe/
                ...
    """

    # Accepted image suffixes; matching is case-insensitive, so
    # ``.JPG``/``.jpeg`` are picked up as well.
    IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}

    def __init__(
        self,
        data_root: str,
        transform: Optional[transforms.Compose] = None,
        split: str = "train"
    ):
        """
        Args:
            data_root: Root directory with one sub-folder per severity label.
            transform: Optional transform applied to each loaded image.
            split: Name of the split ("train"/"val"), used for logging only.
        """
        self.data_root = Path(data_root)
        self.transform = transform
        self.split = split

        # Collect (image_path, label_index) pairs.
        # Sorting makes the sample order deterministic across filesystems —
        # bare Path.glob() order is unspecified, which previously made runs
        # non-reproducible. Suffix matching is also case-insensitive and
        # covers .jpeg, which plain "*.jpg"/"*.png" globs missed.
        self.samples = []

        for label_idx, label_name in SEVERITY_LABELS.items():
            label_dir = self.data_root / label_name
            if label_dir.exists():
                for img_path in sorted(label_dir.iterdir()):
                    if img_path.suffix.lower() in self.IMAGE_EXTENSIONS:
                        self.samples.append((img_path, label_idx))

        logger.info(f"Loaded {len(self.samples)} samples for {split}")

    def __len__(self) -> int:
        """Number of collected samples."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """Load one sample; returns (transformed image, label index)."""
        img_path, label = self.samples[idx]

        image = Image.open(img_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, label
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
def train_classifier(
    train_data_root: str,
    val_data_root: str,
    output_dir: str,
    backbone: str = "efficientnet_b0",
    epochs: int = 50,
    batch_size: int = 32,
    learning_rate: float = 1e-4,
    device: str = "cuda"
):
    """
    Train the severity classifier.

    Saves the best checkpoint (highest validation accuracy) to
    ``output_dir/best.pt`` and the last-epoch weights to
    ``output_dir/final.pt``.

    Args:
        train_data_root: Path to training data
        val_data_root: Path to validation data
        output_dir: Where to save checkpoints
        backbone: Model backbone to use
        epochs: Number of training epochs
        batch_size: Training batch size
        learning_rate: Initial learning rate
        device: Device to train on
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Instantiated only to reuse its train/eval transform pipelines.
    classifier = SeverityClassifier()

    # Create datasets
    train_dataset = PlantDiseaseDataset(
        train_data_root,
        transform=classifier.train_transform,
        split="train"
    )
    val_dataset = PlantDiseaseDataset(
        val_data_root,
        transform=classifier.transform,
        split="val"
    )

    # Create dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4
    )

    # Initialize model
    model = SeverityClassifierCNN(
        num_classes=4,
        backbone=backbone,
        pretrained=True
    ).to(device)

    # Loss and optimizer; cosine schedule decays the LR over the full run.
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

    best_val_acc = 0.0

    for epoch in range(epochs):
        # ---- Training phase ----
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = outputs.max(1)
            train_total += labels.size(0)
            train_correct += predicted.eq(labels).sum().item()

        # ---- Validation phase ----
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)

                val_loss += loss.item()
                _, predicted = outputs.max(1)
                val_total += labels.size(0)
                val_correct += predicted.eq(labels).sum().item()

        train_acc = 100. * train_correct / train_total
        val_acc = 100. * val_correct / val_total

        # Step once per epoch (per CosineAnnealingLR convention).
        scheduler.step()

        logger.info(
            f"Epoch {epoch+1}/{epochs} - "
            f"Train Loss: {train_loss/len(train_loader):.4f}, Train Acc: {train_acc:.2f}% - "
            f"Val Loss: {val_loss/len(val_loader):.4f}, Val Acc: {val_acc:.2f}%"
        )

        # Save best model (selection is by validation accuracy only).
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "val_acc": val_acc,
                "backbone": backbone
            }, output_dir / "best.pt")
            logger.info(f"Saved best model with val_acc: {val_acc:.2f}%")

    # Save final model.
    # NOTE(review): "val_acc" here is the LAST epoch's accuracy, not the best.
    torch.save({
        "epoch": epochs,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "val_acc": val_acc,
        "backbone": backbone
    }, output_dir / "final.pt")

    logger.info(f"Training complete. Best val_acc: {best_val_acc:.2f}%")
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
if __name__ == "__main__":
    # Quick smoke test: run the classifier on a synthetic solid-color image.
    classifier = SeverityClassifier()

    # Create a solid-brown test image matching the model's input size.
    test_image = Image.new("RGB", (224, 224), color=(139, 69, 19))  # Brown

    # Without a checkpoint, load_model falls back to the pretrained backbone
    # with an untrained head, so the prediction is not meaningful — this only
    # exercises the preprocessing/inference path end to end.
    result = classifier.classify(test_image)

    print(f"Severity: {result.severity_label}")
    print(f"Confidence: {result.confidence:.2f}")
    print(f"Probabilities: {result.probabilities}")
    print(f"Description: {SEVERITY_DESCRIPTIONS[result.severity_level]}")
|
src/treatment_recommender.py
ADDED
|
@@ -0,0 +1,525 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Treatment Recommender Module for CropDoctor-Semantic
|
| 3 |
+
=====================================================
|
| 4 |
+
|
| 5 |
+
This module uses Claude API to generate contextual treatment
|
| 6 |
+
recommendations based on disease diagnosis results.
|
| 7 |
+
|
| 8 |
+
Features:
|
| 9 |
+
- Disease identification from symptoms
|
| 10 |
+
- Treatment recommendations (organic/chemical)
|
| 11 |
+
- Preventive measures
|
| 12 |
+
- Timing and application guidance
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
from typing import Dict, List, Optional, Union
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
import json
|
| 19 |
+
import logging
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
class TreatmentRecommendation:
    """Container for treatment recommendations."""
    # Display name of the diagnosed disease.
    disease_name: str
    disease_type: str  # fungal, bacterial, viral, pest, nutrient
    # Diagnosis confidence. NOTE(review): the LLM prompt requests 0-100 while
    # the field comment elsewhere implies 0-1 — confirm and normalize scale.
    confidence: float
    # Input symptoms that matched the diagnosis.
    symptoms_matched: List[str]
    # Organic / natural treatment options.
    organic_treatments: List[str]
    # Chemical treatment options.
    chemical_treatments: List[str]
    # Prevention strategies.
    preventive_measures: List[str]
    # When and how often to apply treatments.
    timing: str
    urgency: str  # low, medium, high, critical
    # Free-form caveats and extra guidance.
    additional_notes: str
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# Disease knowledge base for offline fallback
# Used when the Anthropic client is unavailable. Schema per entry:
#   name       - display name of the disease
#   type       - one of: fungal, bacterial, viral, pest, nutrient
#   symptoms   - indicative symptoms used for matching
#   organic    - organic/natural treatment options
#   chemical   - chemical treatment options
#   prevention - preventive measures
#   timing     - application timing guidance
DISEASE_DATABASE = {
    "powdery_mildew": {
        "name": "Powdery Mildew",
        "type": "fungal",
        "symptoms": ["white powdery coating", "curled leaves", "stunted growth"],
        "organic": [
            "Neem oil spray (1 tbsp per liter water)",
            "Baking soda solution (1 tsp per liter + few drops soap)",
            "Milk spray (40% milk, 60% water)",
            "Sulfur-based fungicide"
        ],
        "chemical": [
            "Myclobutanil",
            "Triadimefon",
            "Propiconazole"
        ],
        "prevention": [
            "Improve air circulation",
            "Avoid overhead watering",
            "Remove infected plant parts",
            "Plant resistant varieties"
        ],
        "timing": "Apply at first sign of infection, repeat every 7-14 days"
    },
    "leaf_spot": {
        "name": "Leaf Spot Disease",
        "type": "fungal",
        "symptoms": ["brown spots", "circular lesions", "yellow halos"],
        "organic": [
            "Copper-based fungicide",
            "Neem oil treatment",
            "Remove and destroy infected leaves"
        ],
        "chemical": [
            "Chlorothalonil",
            "Mancozeb",
            "Azoxystrobin"
        ],
        "prevention": [
            "Water at soil level",
            "Mulch around plants",
            "Rotate crops annually",
            "Maintain proper spacing"
        ],
        "timing": "Begin treatment early, apply every 7-10 days during wet weather"
    },
    "bacterial_blight": {
        "name": "Bacterial Blight",
        "type": "bacterial",
        "symptoms": ["water-soaked lesions", "angular spots", "wilting"],
        "organic": [
            "Copper hydroxide spray",
            "Remove infected plants",
            "Improve drainage"
        ],
        "chemical": [
            "Streptomycin sulfate",
            "Oxytetracycline",
            "Copper-based bactericides"
        ],
        "prevention": [
            "Use disease-free seeds",
            "Avoid working with wet plants",
            "Sanitize tools between plants",
            "Practice crop rotation"
        ],
        "timing": "Apply preventively or at first symptoms, difficult to cure once established"
    },
    "viral_mosaic": {
        "name": "Viral Mosaic Disease",
        "type": "viral",
        "symptoms": ["mosaic pattern", "mottled leaves", "distorted growth"],
        "organic": [
            "Remove and destroy infected plants",
            "Control aphid vectors",
            "Use reflective mulch"
        ],
        "chemical": [
            "No direct treatment available",
            "Control vectors with insecticides"
        ],
        "prevention": [
            "Plant resistant varieties",
            "Control aphid populations",
            "Remove weeds that harbor virus",
            "Sanitize tools frequently"
        ],
        "timing": "Prevention is key - no cure available once infected"
    },
    "nutrient_deficiency": {
        "name": "Nutrient Deficiency",
        "type": "nutrient",
        "symptoms": ["chlorosis", "yellowing", "purple coloration"],
        "organic": [
            "Compost application",
            "Foliar feeding with seaweed extract",
            "Fish emulsion fertilizer"
        ],
        "chemical": [
            "Balanced NPK fertilizer",
            "Iron chelate for iron deficiency",
            "Epsom salt for magnesium deficiency"
        ],
        "prevention": [
            "Regular soil testing",
            "Maintain soil pH 6.0-7.0",
            "Add organic matter annually",
            "Mulch to retain nutrients"
        ],
        "timing": "Apply fertilizers during active growth, avoid over-fertilization"
    },
    "aphid_infestation": {
        "name": "Aphid Infestation",
        "type": "pest",
        "symptoms": ["sticky residue", "curled leaves", "visible insects"],
        "organic": [
            "Strong water spray to dislodge",
            "Neem oil spray",
            "Insecticidal soap",
            "Release ladybugs or lacewings"
        ],
        "chemical": [
            "Pyrethrin-based insecticides",
            "Imidacloprid (systemic)",
            "Malathion"
        ],
        "prevention": [
            "Encourage beneficial insects",
            "Remove weeds",
            "Avoid excess nitrogen fertilization",
            "Use reflective mulch"
        ],
        "timing": "Treat early when populations are low, monitor regularly"
    }
}
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
class TreatmentRecommender:
|
| 179 |
+
"""
|
| 180 |
+
Generate treatment recommendations using LLM.
|
| 181 |
+
|
| 182 |
+
Uses Claude API for intelligent, contextual recommendations
|
| 183 |
+
with fallback to local knowledge base.
|
| 184 |
+
|
| 185 |
+
Example:
|
| 186 |
+
>>> recommender = TreatmentRecommender(api_key="your_key")
|
| 187 |
+
>>> result = recommender.get_recommendation(
|
| 188 |
+
... symptoms=["brown spots", "yellowing"],
|
| 189 |
+
... severity="moderate",
|
| 190 |
+
... plant_species="tomato"
|
| 191 |
+
... )
|
| 192 |
+
>>> print(result.disease_name)
|
| 193 |
+
>>> print(result.organic_treatments)
|
| 194 |
+
"""
|
| 195 |
+
|
| 196 |
+
def __init__(
|
| 197 |
+
self,
|
| 198 |
+
api_key: Optional[str] = None,
|
| 199 |
+
use_llm: bool = True,
|
| 200 |
+
model: str = "claude-sonnet-4-20250514"
|
| 201 |
+
):
|
| 202 |
+
"""
|
| 203 |
+
Initialize treatment recommender.
|
| 204 |
+
|
| 205 |
+
Args:
|
| 206 |
+
api_key: Anthropic API key (uses env var ANTHROPIC_API_KEY if None)
|
| 207 |
+
use_llm: Whether to use LLM for recommendations
|
| 208 |
+
model: Claude model to use
|
| 209 |
+
"""
|
| 210 |
+
self.use_llm = use_llm
|
| 211 |
+
self.model = model
|
| 212 |
+
self.client = None
|
| 213 |
+
|
| 214 |
+
if use_llm:
|
| 215 |
+
api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
|
| 216 |
+
if api_key:
|
| 217 |
+
try:
|
| 218 |
+
import anthropic
|
| 219 |
+
self.client = anthropic.Anthropic(api_key=api_key)
|
| 220 |
+
logger.info("Anthropic client initialized successfully")
|
| 221 |
+
except ImportError:
|
| 222 |
+
logger.warning("anthropic package not installed, using offline mode")
|
| 223 |
+
self.use_llm = False
|
| 224 |
+
else:
|
| 225 |
+
logger.warning("No API key provided, using offline knowledge base")
|
| 226 |
+
self.use_llm = False
|
| 227 |
+
|
| 228 |
+
def _build_prompt(
|
| 229 |
+
self,
|
| 230 |
+
symptoms: List[str],
|
| 231 |
+
severity: str,
|
| 232 |
+
plant_species: Optional[str] = None,
|
| 233 |
+
affected_area_percent: Optional[float] = None,
|
| 234 |
+
additional_context: Optional[str] = None
|
| 235 |
+
) -> str:
|
| 236 |
+
"""Build the prompt for Claude."""
|
| 237 |
+
|
| 238 |
+
prompt = f"""You are an expert plant pathologist providing diagnosis and treatment recommendations.
|
| 239 |
+
|
| 240 |
+
Based on the following observations, identify the most likely disease and provide treatment recommendations.
|
| 241 |
+
|
| 242 |
+
## Observations:
|
| 243 |
+
- **Symptoms detected**: {', '.join(symptoms)}
|
| 244 |
+
- **Severity level**: {severity}
|
| 245 |
+
"""
|
| 246 |
+
|
| 247 |
+
if plant_species:
|
| 248 |
+
prompt += f"- **Plant species**: {plant_species}\n"
|
| 249 |
+
|
| 250 |
+
if affected_area_percent is not None:
|
| 251 |
+
prompt += f"- **Affected area**: {affected_area_percent:.1f}%\n"
|
| 252 |
+
|
| 253 |
+
if additional_context:
|
| 254 |
+
prompt += f"- **Additional context**: {additional_context}\n"
|
| 255 |
+
|
| 256 |
+
prompt += """
|
| 257 |
+
## Please provide:
|
| 258 |
+
1. **Disease identification**: Name and type (fungal/bacterial/viral/pest/nutrient)
|
| 259 |
+
2. **Confidence level**: How confident are you in this diagnosis (0-100%)
|
| 260 |
+
3. **Organic treatments**: List 3-5 organic/natural treatment options
|
| 261 |
+
4. **Chemical treatments**: List 2-3 chemical treatment options (if needed)
|
| 262 |
+
5. **Preventive measures**: List 3-5 prevention strategies
|
| 263 |
+
6. **Timing**: When and how often to apply treatments
|
| 264 |
+
7. **Urgency**: How urgent is treatment (low/medium/high/critical)
|
| 265 |
+
8. **Additional notes**: Any important considerations
|
| 266 |
+
|
| 267 |
+
Format your response as JSON with these exact keys:
|
| 268 |
+
{
|
| 269 |
+
"disease_name": "string",
|
| 270 |
+
"disease_type": "string",
|
| 271 |
+
"confidence": number,
|
| 272 |
+
"symptoms_matched": ["string"],
|
| 273 |
+
"organic_treatments": ["string"],
|
| 274 |
+
"chemical_treatments": ["string"],
|
| 275 |
+
"preventive_measures": ["string"],
|
| 276 |
+
"timing": "string",
|
| 277 |
+
"urgency": "string",
|
| 278 |
+
"additional_notes": "string"
|
| 279 |
+
}
|
| 280 |
+
"""
|
| 281 |
+
return prompt
|
| 282 |
+
|
| 283 |
+
def get_recommendation(
    self,
    symptoms: List[str],
    severity: str = "moderate",
    plant_species: Optional[str] = None,
    affected_area_percent: Optional[float] = None,
    additional_context: Optional[str] = None
) -> TreatmentRecommendation:
    """
    Produce a treatment recommendation for the detected symptoms.

    Prefers the LLM-backed path when a client is configured; otherwise
    falls back to the local knowledge-base lookup.

    Args:
        symptoms: Detected symptom descriptions.
        severity: Severity level (healthy, mild, moderate, severe).
        plant_species: Optional plant species for more specific advice.
        affected_area_percent: Percentage of the plant that is affected.
        additional_context: Any additional observations.

    Returns:
        TreatmentRecommendation with diagnosis and treatment options.
    """
    # The offline knowledge base is the fallback whenever no LLM client exists.
    if not (self.use_llm and self.client):
        return self._get_offline_recommendation(symptoms, severity)

    return self._get_llm_recommendation(
        symptoms, severity, plant_species,
        affected_area_percent, additional_context
    )
|
| 311 |
+
|
| 312 |
+
def _get_llm_recommendation(
    self,
    symptoms: List[str],
    severity: str,
    plant_species: Optional[str],
    affected_area_percent: Optional[float],
    additional_context: Optional[str]
) -> TreatmentRecommendation:
    """Query the Claude API and parse its JSON reply into a recommendation.

    Falls back to the offline knowledge base when the request fails or the
    reply contains no parseable JSON object.
    """
    prompt = self._build_prompt(
        symptoms, severity, plant_species,
        affected_area_percent, additional_context
    )

    try:
        response = self.client.messages.create(
            model=self.model,
            max_tokens=1500,
            messages=[{"role": "user", "content": prompt}]
        )

        text = response.content[0].text

        # The model is asked for JSON; grab the outermost {...} span in case
        # it wrapped the object in extra prose.
        first = text.find('{')
        last = text.rfind('}') + 1

        if not (first >= 0 and last > first):
            logger.warning("Could not parse LLM response, using offline mode")
            return self._get_offline_recommendation(symptoms, severity)

        data = json.loads(text[first:last])

        return TreatmentRecommendation(
            disease_name=data.get("disease_name", "Unknown"),
            disease_type=data.get("disease_type", "unknown"),
            # Prompt requests confidence on a 0-100 scale; store it as 0-1.
            confidence=data.get("confidence", 0) / 100,
            symptoms_matched=data.get("symptoms_matched", symptoms),
            organic_treatments=data.get("organic_treatments", []),
            chemical_treatments=data.get("chemical_treatments", []),
            preventive_measures=data.get("preventive_measures", []),
            timing=data.get("timing", "Consult local extension service"),
            urgency=data.get("urgency", "medium"),
            additional_notes=data.get("additional_notes", "")
        )

    except Exception as e:
        # Any API or JSON failure degrades gracefully to the offline path.
        logger.error(f"LLM request failed: {e}")
        return self._get_offline_recommendation(symptoms, severity)
|
| 365 |
+
|
| 366 |
+
def _get_offline_recommendation(
    self,
    symptoms: List[str],
    severity: str
) -> TreatmentRecommendation:
    """Match symptoms against the local disease knowledge base.

    Picks the database entry whose catalogued symptoms share the most
    word-level overlap with the reported symptoms; returns a generic
    recommendation when nothing matches at all.
    """
    reported = [s.lower() for s in symptoms]

    best_key = None
    best_score = 0
    best_matches: List[str] = []

    for key, entry in DISEASE_DATABASE.items():
        # A catalogue symptom counts once if any reported symptom contains
        # one of its words (cheap word-level fuzzy match).
        hits = [
            cat_sym for cat_sym in entry["symptoms"]
            if any(word in rep for rep in reported for word in cat_sym.split())
        ]
        if len(hits) > best_score:
            best_score = len(hits)
            best_key = key
            best_matches = hits

    if best_key is None:
        # Nothing matched: return a conservative, generic recommendation.
        return TreatmentRecommendation(
            disease_name="Unidentified Condition",
            disease_type="unknown",
            confidence=0.3,
            symptoms_matched=symptoms,
            organic_treatments=[
                "Apply broad-spectrum organic fungicide",
                "Improve plant nutrition",
                "Remove affected plant parts"
            ],
            chemical_treatments=[
                "Consult local agricultural extension",
                "Get professional diagnosis"
            ],
            preventive_measures=[
                "Improve air circulation",
                "Avoid overwatering",
                "Maintain plant hygiene"
            ],
            timing="Monitor closely and treat at first sign of worsening",
            urgency="medium" if severity in ["mild", "moderate"] else "high",
            additional_notes="Unable to identify specific disease. Consider professional diagnosis."
        )

    disease = DISEASE_DATABASE[best_key]

    # Treatment urgency scales directly with observed severity.
    urgency_by_severity = {
        "healthy": "low",
        "mild": "medium",
        "moderate": "high",
        "severe": "critical"
    }

    return TreatmentRecommendation(
        disease_name=disease["name"],
        disease_type=disease["type"],
        # Confidence grows with the match count but is capped at 0.9.
        confidence=min(0.9, 0.5 + (best_score * 0.15)),
        symptoms_matched=best_matches,
        organic_treatments=disease["organic"],
        chemical_treatments=disease["chemical"],
        preventive_measures=disease["prevention"],
        timing=disease["timing"],
        urgency=urgency_by_severity.get(severity, "medium"),
        additional_notes=f"Diagnosis based on {best_score} symptom matches."
    )
|
| 444 |
+
|
| 445 |
+
def format_report(
    self,
    recommendation: "TreatmentRecommendation",
    include_chemical: bool = True
) -> str:
    """
    Render a recommendation as a human-readable text report.

    Args:
        recommendation: TreatmentRecommendation to format.
        include_chemical: Whether to include the chemical-treatment section.

    Returns:
        Formatted string report.
    """
    # Emoji used as a quick visual cue for treatment urgency.
    urgency_emoji = {
        "low": "π’",
        "medium": "π‘",
        "high": "π ",
        "critical": "π΄"
    }

    parts = [f"""
ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
β                    DIAGNOSTIC REPORT                          β
ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ

π DIAGNOSIS
Disease: {recommendation.disease_name}
Type: {recommendation.disease_type.capitalize()}
Confidence: {recommendation.confidence:.0%}
Urgency: {urgency_emoji.get(recommendation.urgency, 'βͺ')} {recommendation.urgency.upper()}

π SYMPTOMS MATCHED
"""]
    parts.extend(f" β’ {symptom}\n" for symptom in recommendation.symptoms_matched)

    parts.append("""
πΏ ORGANIC TREATMENTS
""")
    parts.extend(f" β’ {treatment}\n" for treatment in recommendation.organic_treatments)

    # Chemical options are suppressed when the caller asks for organic-only.
    if include_chemical and recommendation.chemical_treatments:
        parts.append("""
π§ͺ CHEMICAL TREATMENTS
""")
        parts.extend(f" β’ {treatment}\n" for treatment in recommendation.chemical_treatments)

    parts.append("""
π‘οΈ PREVENTIVE MEASURES
""")
    parts.extend(f" β’ {measure}\n" for measure in recommendation.preventive_measures)

    parts.append(f"""
β° TIMING
{recommendation.timing}

π NOTES
{recommendation.additional_notes}

βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
""")
    return "".join(parts)
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
if __name__ == "__main__":
    # Smoke-test the offline (no-API) recommendation path.
    recommender = TreatmentRecommender(use_llm=False)

    demo = recommender.get_recommendation(
        symptoms=["brown spots", "yellow halos", "circular lesions"],
        severity="moderate",
        plant_species="tomato"
    )

    print(recommender.format_report(demo))
|
src/visualization.py
ADDED
|
@@ -0,0 +1,460 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Visualization Module for CropDoctor-Semantic
|
| 3 |
+
=============================================
|
| 4 |
+
|
| 5 |
+
This module provides visualization functions for:
|
| 6 |
+
- Segmentation masks overlay
|
| 7 |
+
- Severity heatmaps
|
| 8 |
+
- Diagnostic dashboards
|
| 9 |
+
- Comparison views
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
from PIL import Image
|
| 14 |
+
import matplotlib.pyplot as plt
|
| 15 |
+
import matplotlib.patches as mpatches
|
| 16 |
+
from matplotlib.colors import LinearSegmentedColormap
|
| 17 |
+
from typing import List, Optional, Tuple, Union, Dict
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
import logging
|
| 20 |
+
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
|
| 23 |
+
# Color schemes
|
| 24 |
+
# Severity label -> display colour (hex), ramping green through red.
SEVERITY_COLORS = {
    'healthy': '#2ECC71',   # Green
    'mild': '#F1C40F',      # Yellow
    'moderate': '#E67E22',  # Orange
    'severe': '#E74C3C'     # Red
}

# Palette cycled through when painting per-symptom mask overlays.
SYMPTOM_COLORS = [
    '#E74C3C',  # Red
    '#9B59B6',  # Purple
    '#3498DB',  # Blue
    '#E67E22',  # Orange
    '#1ABC9C',  # Teal
    '#F39C12',  # Yellow
    '#D35400',  # Dark Orange
    '#8E44AD',  # Dark Purple
]
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def create_diagnostic_visualization(
    image: Union[str, Path, Image.Image, np.ndarray],
    masks: Optional[np.ndarray] = None,
    severity_label: str = "unknown",
    disease_name: str = "Unknown",
    affected_percent: float = 0.0,
    prompt_labels: Optional[List[str]] = None,
    figsize: Tuple[int, int] = (16, 6)
) -> plt.Figure:
    """
    Create a comprehensive diagnostic visualization.

    Three panels: original image | mask overlay | text/graphic summary.

    Args:
        image: Input image (path, PIL image, or array)
        masks: Segmentation masks array (N, H, W)
        severity_label: Severity classification result
        disease_name: Identified disease name
        affected_percent: Percentage of affected area (expected 0-100)
        prompt_labels: Labels for each mask (used for the overlay legend)
        figsize: Figure size

    Returns:
        matplotlib Figure object
    """
    # Normalize the input to a PIL image, then to an RGB array.
    if isinstance(image, (str, Path)):
        image = Image.open(image).convert("RGB")
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    img_array = np.array(image)

    # Create figure with subplots
    fig, axes = plt.subplots(1, 3, figsize=figsize)
    fig.suptitle(f'CropDoctor Diagnostic Report', fontsize=14, fontweight='bold')

    # Panel 1: Original Image, untouched.
    axes[0].imshow(img_array)
    axes[0].set_title('Original Image', fontsize=12)
    axes[0].axis('off')

    # Panel 2: Segmentation overlay (or plain image when nothing detected).
    if masks is not None and len(masks) > 0:
        overlay = create_mask_overlay(img_array, masks, alpha=0.5)
        axes[1].imshow(overlay)

        # Legend maps each prompt label to its overlay colour; labels beyond
        # the palette size are dropped from the legend.
        if prompt_labels:
            patches = []
            for i, label in enumerate(prompt_labels[:len(SYMPTOM_COLORS)]):
                color = SYMPTOM_COLORS[i % len(SYMPTOM_COLORS)]
                patches.append(mpatches.Patch(color=color, label=label, alpha=0.7))
            axes[1].legend(handles=patches, loc='upper right', fontsize=8)
    else:
        axes[1].imshow(img_array)
        axes[1].text(0.5, 0.5, 'No disease regions detected',
                     transform=axes[1].transAxes, ha='center', va='center',
                     fontsize=12, color='green')

    axes[1].set_title('Disease Detection', fontsize=12)
    axes[1].axis('off')

    # Panel 3: Diagnostic summary (text + severity / affected-area bars).
    axes[2].axis('off')

    # Unknown severities fall back to neutral grey.
    severity_color = SEVERITY_COLORS.get(severity_label.lower(), '#95A5A6')

    summary_text = f"""
ββββββββββββββββββββββββββββββββββ
β       DIAGNOSTIC SUMMARY       β
ββββββββββββββββββββββββββββββββββ

π Disease: {disease_name}

β οΈ Severity: {severity_label.upper()}

π Affected Area: {affected_percent:.1f}%

"""

    # Monospace keeps the box-drawing characters aligned.
    axes[2].text(0.5, 0.65, summary_text, transform=axes[2].transAxes,
                 fontsize=11, fontfamily='monospace',
                 verticalalignment='top', horizontalalignment='center')

    # Solid colour bar labelled with the severity level.
    severity_bar = plt.Rectangle((0.15, 0.25), 0.7, 0.1,
                                 facecolor=severity_color,
                                 edgecolor='black',
                                 transform=axes[2].transAxes)
    axes[2].add_patch(severity_bar)
    axes[2].text(0.5, 0.30, severity_label.upper(),
                 transform=axes[2].transAxes, ha='center', va='center',
                 fontsize=12, fontweight='bold', color='white')

    # Progress bar: filled width proportional to affected percentage; the
    # 0.01 floor keeps the bar visible even at ~0%.
    bar_width = 0.7 * (affected_percent / 100)
    bg_bar = plt.Rectangle((0.15, 0.12), 0.7, 0.06,
                           facecolor='#EEEEEE', edgecolor='black',
                           transform=axes[2].transAxes)
    progress_bar = plt.Rectangle((0.15, 0.12), max(0.01, bar_width), 0.06,
                                 facecolor='#E74C3C',
                                 transform=axes[2].transAxes)
    axes[2].add_patch(bg_bar)
    axes[2].add_patch(progress_bar)
    axes[2].text(0.5, 0.08, f'Affected Area: {affected_percent:.1f}%',
                 transform=axes[2].transAxes, ha='center', fontsize=10)

    plt.tight_layout()

    return fig
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def create_mask_overlay(
    image: np.ndarray,
    masks: np.ndarray,
    alpha: float = 0.5,
    colors: Optional[List[str]] = None
) -> np.ndarray:
    """
    Create an overlay of segmentation masks on an image.

    Args:
        image: RGB image array (H, W, 3)
        masks: Masks (N, H, W); boolean preferred, but 0/1 integer masks
            are accepted and coerced to bool
        alpha: Transparency of overlay in [0, 1]
        colors: Optional list of hex colors for masks, cycled per mask
            (defaults to SYMPTOM_COLORS)

    Returns:
        uint8 image array with mask overlay applied
    """
    if colors is None:
        colors = SYMPTOM_COLORS

    # Work in float so the alpha blend doesn't truncate.
    overlay = image.copy().astype(np.float32)

    for i, mask in enumerate(masks):
        # Coerce to bool: an integer 0/1 mask would otherwise be treated as
        # a fancy-index array by `colored_mask[mask]` and paint the wrong
        # pixels instead of selecting where mask is set.
        mask = np.asarray(mask, dtype=bool)
        if not mask.any():
            continue

        # Color for this mask (cycles through the palette).
        color_rgb = hex_to_rgb(colors[i % len(colors)])

        # Colored layer that is nonzero only inside the mask.
        colored_mask = np.zeros_like(overlay)
        colored_mask[mask] = color_rgb

        # Alpha-blend only inside the mask; pixels outside are untouched.
        mask_3d = np.stack([mask] * 3, axis=-1)
        overlay = np.where(
            mask_3d,
            overlay * (1 - alpha) + colored_mask * alpha,
            overlay
        )

    return overlay.astype(np.uint8)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def create_severity_heatmap(
    image: Union[str, Path, Image.Image, np.ndarray],
    severity_map: np.ndarray,
    figsize: Tuple[int, int] = (12, 5)
) -> plt.Figure:
    """
    Show the original image next to a severity heatmap overlay.

    Args:
        image: Input image
        severity_map: Array of severity values (H, W) with values 0-3
        figsize: Figure size

    Returns:
        matplotlib Figure object
    """
    # Normalize the input to a PIL image, then to an RGB array.
    if isinstance(image, (str, Path)):
        image = Image.open(image).convert("RGB")
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    pixels = np.array(image)

    # Green -> yellow -> orange -> red ramp matching the severity scale.
    ramp = ['#2ECC71', '#F1C40F', '#E67E22', '#E74C3C']
    cmap = LinearSegmentedColormap.from_list('severity', ramp, N=256)

    fig, (ax_orig, ax_heat) = plt.subplots(1, 2, figsize=figsize)

    # Left panel: untouched photo.
    ax_orig.imshow(pixels)
    ax_orig.set_title('Original Image')
    ax_orig.axis('off')

    # Right panel: semi-transparent heatmap drawn over the photo.
    ax_heat.imshow(pixels)
    heat_image = ax_heat.imshow(severity_map, cmap=cmap, alpha=0.6, vmin=0, vmax=3)
    ax_heat.set_title('Severity Heatmap')
    ax_heat.axis('off')

    # Colorbar labelled with the four severity classes.
    cbar = plt.colorbar(heat_image, ax=ax_heat, fraction=0.046, pad=0.04)
    cbar.set_ticks([0, 1, 2, 3])
    cbar.set_ticklabels(['Healthy', 'Mild', 'Moderate', 'Severe'])

    plt.tight_layout()

    return fig
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def create_comparison_view(
    images: List[Union[str, Path, Image.Image]],
    results: List[Dict],
    cols: int = 4,
    figsize_per_image: Tuple[float, float] = (4, 5)
) -> plt.Figure:
    """
    Create a grid comparison view of multiple diagnoses.

    Each cell shows one image titled with its disease name and severity,
    coloured by severity level.

    Args:
        images: List of images
        results: List of diagnostic results (dicts with 'severity_label', 'disease_name', etc.)
        cols: Number of columns in grid
        figsize_per_image: Size per image in the grid

    Returns:
        matplotlib Figure object
    """
    n_images = len(images)
    rows = (n_images + cols - 1) // cols

    # squeeze=False guarantees a 2-D axes array for every grid shape.
    # The previous list-wrapping approach broke for single-row grids:
    # with rows == 1 it indexed the wrapper list (`axes[col]`) instead of
    # the Axes array, raising for any column past the first.
    fig, axes = plt.subplots(
        rows, cols,
        figsize=(figsize_per_image[0] * cols, figsize_per_image[1] * rows),
        squeeze=False
    )

    for i, (img, result) in enumerate(zip(images, results)):
        ax = axes[i // cols][i % cols]

        # Load image lazily from disk when a path is given.
        if isinstance(img, (str, Path)):
            img = Image.open(img).convert("RGB")

        ax.imshow(img)
        ax.axis('off')

        # Colored border based on severity; unknown severities get grey.
        severity = result.get('severity_label', 'unknown')
        color = SEVERITY_COLORS.get(severity.lower(), '#95A5A6')

        for spine in ax.spines.values():
            spine.set_edgecolor(color)
            spine.set_linewidth(4)
            spine.set_visible(True)

        ax.set_title(
            f"{result.get('disease_name', 'Unknown')}\n{severity.upper()}",
            fontsize=10,
            color=color
        )

    # Blank out any unused trailing grid cells.
    for i in range(n_images, rows * cols):
        ax = axes[i // cols][i % cols]
        ax.axis('off')
        ax.set_visible(False)

    plt.tight_layout()

    return fig
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def create_treatment_card(
    result: Dict,
    figsize: Tuple[int, int] = (8, 10)
) -> plt.Figure:
    """
    Create a treatment recommendation card.

    Renders disease info, organic/chemical treatments, prevention tips and
    timing as text on a bordered figure; the border colour encodes severity.

    Args:
        result: Diagnostic result dictionary (expected keys include
            'disease_name', 'disease_type', 'severity_label',
            'affected_area_percent', 'organic_treatments',
            'chemical_treatments', 'preventive_measures',
            'treatment_timing' — all optional, with fallbacks)
        figsize: Figure size

    Returns:
        matplotlib Figure object
    """
    fig, ax = plt.subplots(figsize=figsize)
    ax.axis('off')

    # Unknown severities fall back to neutral grey.
    severity_color = SEVERITY_COLORS.get(
        result.get('severity_label', 'unknown').lower(),
        '#95A5A6'
    )

    # Title
    ax.text(0.5, 0.95, 'πΏ TREATMENT CARD',
            ha='center', va='top', fontsize=16, fontweight='bold',
            transform=ax.transAxes)

    # Disease info box; field padding keeps the box-drawing edges aligned
    # when rendered in a monospace font.
    disease_text = f"""
βββββββββββββββββββββββββββββββββββββββββββββ
β Disease: {result.get('disease_name', 'Unknown'):<32}β
β Type: {result.get('disease_type', 'unknown'):<35}β
β Severity: {result.get('severity_label', 'unknown').upper():<31}β
β Affected Area: {result.get('affected_area_percent', 0):.1f}%{' ' * 25}β
βββββββββββββββββββββββββββββββββββββββββββββ
"""
    ax.text(0.5, 0.85, disease_text,
            ha='center', va='top', fontfamily='monospace', fontsize=10,
            transform=ax.transAxes)

    # Sections are laid out top-down; y_pos is the running layout cursor
    # (decremented by 0.03 per line, 0.02 between sections).
    y_pos = 0.60

    # Organic treatments: at most 4 entries, each truncated to 50 chars.
    ax.text(0.1, y_pos, 'π± ORGANIC TREATMENTS', fontweight='bold', fontsize=11,
            transform=ax.transAxes)
    y_pos -= 0.03

    for treatment in result.get('organic_treatments', [])[:4]:
        ax.text(0.12, y_pos, f'β’ {treatment[:50]}', fontsize=9,
                transform=ax.transAxes)
        y_pos -= 0.03

    y_pos -= 0.02

    # Chemical treatments: optional section, at most 3 entries.
    if result.get('chemical_treatments'):
        ax.text(0.1, y_pos, 'π§ͺ CHEMICAL TREATMENTS', fontweight='bold', fontsize=11,
                transform=ax.transAxes)
        y_pos -= 0.03

        for treatment in result.get('chemical_treatments', [])[:3]:
            ax.text(0.12, y_pos, f'β’ {treatment[:50]}', fontsize=9,
                    transform=ax.transAxes)
            y_pos -= 0.03

    y_pos -= 0.02

    # Prevention: at most 4 entries.
    ax.text(0.1, y_pos, 'π‘οΈ PREVENTION', fontweight='bold', fontsize=11,
            transform=ax.transAxes)
    y_pos -= 0.03

    for measure in result.get('preventive_measures', [])[:4]:
        ax.text(0.12, y_pos, f'β’ {measure[:50]}', fontsize=9,
                transform=ax.transAxes)
        y_pos -= 0.03

    # Timing line, truncated to 60 chars.
    y_pos -= 0.02
    ax.text(0.1, y_pos, f"β° TIMING: {result.get('treatment_timing', 'Consult expert')[:60]}",
            fontsize=9, transform=ax.transAxes)

    # Severity-coloured border around the whole card.
    rect = plt.Rectangle((0.05, 0.05), 0.9, 0.92,
                         fill=False, edgecolor=severity_color, linewidth=3,
                         transform=ax.transAxes)
    ax.add_patch(rect)

    return fig
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
def hex_to_rgb(hex_color: str) -> List[int]:
    """Convert a '#RRGGBB' (or 'RRGGBB') hex string to an [R, G, B] list."""
    digits = hex_color.lstrip('#')
    return [int(digits[off:off + 2], 16) for off in range(0, 6, 2)]
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
def save_visualization(
    fig: plt.Figure,
    output_path: Union[str, Path],
    dpi: int = 150
):
    """Write *fig* to *output_path* (creating parent dirs) and close it."""
    target = Path(output_path)
    # Ensure the destination directory exists before writing.
    target.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(target, dpi=dpi, bbox_inches='tight', facecolor='white')
    # Close to release figure memory once it is on disk.
    plt.close(fig)
    logger.info(f"Visualization saved to {target}")
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
if __name__ == "__main__":
    # Quick visual smoke test with synthetic data.
    import numpy as np

    # Random image with a greenish tint to mimic foliage.
    demo_img = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
    demo_img[:, :, 1] = 139  # Greenish tint

    # Two square dummy masks.
    demo_masks = np.zeros((2, 480, 640), dtype=bool)
    demo_masks[0, 100:200, 100:200] = True
    demo_masks[1, 300:400, 400:500] = True

    fig = create_diagnostic_visualization(
        demo_img,
        demo_masks,
        severity_label="moderate",
        disease_name="Leaf Spot Disease",
        affected_percent=15.5,
        prompt_labels=["brown spots", "yellowing"]
    )

    save_visualization(fig, "/tmp/test_diagnostic.png")
    print("Test visualization saved to /tmp/test_diagnostic.png")
|