davidsv commited on
Commit
f8eb07d
Β·
1 Parent(s): d7b4d0c

Add disease detection app with RF-DETR and SAM2

Browse files
app.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ CropDoctor - Plant Disease Detection
4
+ Hugging Face Space Demo
5
+ """
6
+
7
+ import gradio as gr
8
+ import numpy as np
9
+ from PIL import Image
10
+ from pathlib import Path
11
+ import torch
12
+
13
+ # Model paths
14
+ MODEL_DIR = Path("models")
15
+ RFDETR_CHECKPOINT = MODEL_DIR / "rfdetr" / "checkpoint_best_total.pth"
16
+ SAM2_CHECKPOINT = MODEL_DIR / "sam2" / "sam2.1_hiera_small.pt"
17
+
18
+ # Lazy loaded components
19
+ segmenter = None
20
+ leaf_segmenter = None
21
+
22
+
23
+ def load_models():
24
+ """Load models on first use."""
25
+ global segmenter, leaf_segmenter
26
+
27
+ if segmenter is not None:
28
+ return
29
+
30
+ print("Loading models...")
31
+
32
+ # Import here to avoid loading at startup
33
+ from src.sam3_segmentation import RFDETRSegmenter
34
+ from src.leaf_segmenter import SAM2LeafSegmenter
35
+
36
+ # Load RF-DETR for disease detection
37
+ segmenter = RFDETRSegmenter(
38
+ checkpoint_path=str(RFDETR_CHECKPOINT),
39
+ model_size="medium"
40
+ )
41
+
42
+ # Load SAM2 for leaf segmentation and mask refinement
43
+ leaf_segmenter = SAM2LeafSegmenter(
44
+ checkpoint_path=str(SAM2_CHECKPOINT)
45
+ )
46
+
47
+ print("Models loaded!")
48
+
49
+
50
+ def detect_disease(
51
+ image: np.ndarray,
52
+ use_leaf_segmentation: bool = True,
53
+ confidence_threshold: float = 0.3
54
+ ) -> tuple:
55
+ """
56
+ Detect plant diseases in an image.
57
+
58
+ Args:
59
+ image: Input image
60
+ use_leaf_segmentation: Whether to isolate leaf first
61
+ confidence_threshold: Detection confidence threshold
62
+
63
+ Returns:
64
+ Tuple of (annotated_image, detection_info)
65
+ """
66
+ if image is None:
67
+ return None, "Please upload an image"
68
+
69
+ load_models()
70
+
71
+ # Convert to PIL
72
+ pil_image = Image.fromarray(image)
73
+ original_image = image.copy()
74
+
75
+ # Step 1: Leaf segmentation (optional)
76
+ segmented_image = None
77
+ leaf_mask = None
78
+ if use_leaf_segmentation:
79
+ segmented_pil, leaf_mask = leaf_segmenter.auto_segment_leaf(
80
+ pil_image, return_mask=True
81
+ )
82
+ segmented_image = np.array(segmented_pil)
83
+ detection_input = segmented_pil
84
+ else:
85
+ detection_input = pil_image
86
+
87
+ # Step 2: Disease detection with RF-DETR
88
+ prompts = ["diseased plant tissue", "leaf spot", "disease symptom"]
89
+ seg_result = segmenter.segment_with_concepts(
90
+ detection_input,
91
+ prompts,
92
+ confidence_threshold=confidence_threshold
93
+ )
94
+
95
+ num_detections = len(seg_result.boxes)
96
+
97
+ # Step 3: Refine boxes to masks with SAM2
98
+ if num_detections > 0:
99
+ refined_masks = leaf_segmenter.refine_boxes_to_masks(
100
+ detection_input,
101
+ seg_result.boxes
102
+ )
103
+ else:
104
+ refined_masks = np.zeros((0, image.shape[0], image.shape[1]), dtype=bool)
105
+
106
+ # Create visualization
107
+ from src.visualization import create_mask_overlay
108
+
109
+ if use_leaf_segmentation and segmented_image is not None:
110
+ base_image = segmented_image
111
+ else:
112
+ base_image = original_image
113
+
114
+ if num_detections > 0:
115
+ annotated = create_mask_overlay(base_image, refined_masks, alpha=0.5)
116
+ else:
117
+ annotated = base_image
118
+
119
+ # Format detection info
120
+ if num_detections == 0:
121
+ info = "### No disease detected\n\nThe leaf appears healthy or no disease symptoms were found."
122
+ else:
123
+ # Calculate affected area
124
+ total_mask = np.zeros(refined_masks[0].shape, dtype=bool)
125
+ for mask in refined_masks:
126
+ total_mask |= mask
127
+
128
+ if use_leaf_segmentation and leaf_mask is not None:
129
+ affected_percent = (total_mask & leaf_mask).sum() / max(leaf_mask.sum(), 1) * 100
130
+ else:
131
+ affected_percent = total_mask.sum() / (image.shape[0] * image.shape[1]) * 100
132
+
133
+ info = f"""### Disease Detected
134
+
135
+ - **Regions found**: {num_detections}
136
+ - **Affected area**: {affected_percent:.1f}%
137
+ - **Confidence threshold**: {confidence_threshold:.0%}
138
+
139
+ The colored overlays show detected disease regions.
140
+ """
141
+
142
+ return annotated, info
143
+
144
+
145
+ def create_demo():
146
+ """Create Gradio interface."""
147
+
148
+ with gr.Blocks(title="CropDoctor - Plant Disease Detection") as demo:
149
+
150
+ gr.Markdown("""
151
+ # CropDoctor - Plant Disease Detection
152
+
153
+ Upload a plant leaf image to detect disease regions using AI.
154
+
155
+ **Models used:**
156
+ - **RF-DETR**: Fine-tuned object detector for disease localization
157
+ - **SAM2**: Segment Anything Model 2 for precise mask generation
158
+ """)
159
+
160
+ with gr.Row():
161
+ with gr.Column():
162
+ input_image = gr.Image(
163
+ label="Upload Plant Image",
164
+ type="numpy",
165
+ height=400
166
+ )
167
+
168
+ with gr.Row():
169
+ leaf_seg_checkbox = gr.Checkbox(
170
+ value=True,
171
+ label="Isolate leaf (SAM2)",
172
+ info="Segment leaf before detection to reduce false positives"
173
+ )
174
+ confidence_slider = gr.Slider(
175
+ minimum=0.1,
176
+ maximum=0.9,
177
+ value=0.3,
178
+ step=0.05,
179
+ label="Confidence Threshold"
180
+ )
181
+
182
+ detect_btn = gr.Button("Detect Disease", variant="primary", size="lg")
183
+
184
+ with gr.Column():
185
+ output_image = gr.Image(
186
+ label="Detection Result",
187
+ type="numpy",
188
+ height=400
189
+ )
190
+ detection_info = gr.Markdown(
191
+ label="Detection Info",
192
+ value="Upload an image and click 'Detect Disease' to see results."
193
+ )
194
+
195
+ # Event handler
196
+ detect_btn.click(
197
+ fn=detect_disease,
198
+ inputs=[input_image, leaf_seg_checkbox, confidence_slider],
199
+ outputs=[output_image, detection_info]
200
+ )
201
+
202
+ gr.Markdown("""
203
+ ---
204
+ **Note**: This demo uses models trained on the PlantVillage dataset.
205
+ For best results, use clear images of individual plant leaves.
206
+ """)
207
+
208
+ return demo
209
+
210
+
211
+ if __name__ == "__main__":
212
+ demo = create_demo()
213
+ demo.launch()
configs/sam3_config.yaml ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SAM 3 Configuration for CropDoctor-Semantic
2
+ # ============================================
3
+
4
+ model:
5
+ name: "sam3"
6
+ checkpoint: "models/sam3/sam3.pt"
7
+ device: "cuda" # cuda, cpu, mps
8
+ half_precision: true # Use FP16 for faster inference
9
+
10
+ inference:
11
+ # Confidence thresholds
12
+ confidence_threshold: 0.25
13
+ presence_threshold: 0.5
14
+
15
+ # Object limits
16
+ max_objects_per_prompt: 50
17
+ min_mask_area: 100 # pixels
18
+
19
+ # Post-processing
20
+ apply_nms: true
21
+ nms_threshold: 0.5
22
+
23
+ # Disease detection prompts (ordered by specificity)
24
+ prompts:
25
+ # General disease detection
26
+ general:
27
+ - "diseased plant tissue"
28
+ - "infected leaf area"
29
+ - "plant disease symptoms"
30
+ - "abnormal leaf coloration"
31
+
32
+ # Healthy reference
33
+ healthy:
34
+ - "healthy green leaf"
35
+ - "normal plant tissue"
36
+ - "unaffected leaf area"
37
+
38
+ # Fungal diseases
39
+ fungal:
40
+ - "powdery mildew coating"
41
+ - "rust pustules on leaf"
42
+ - "leaf spot lesions"
43
+ - "anthracnose dark lesions"
44
+ - "downy mildew"
45
+ - "gray mold"
46
+ - "scab lesions"
47
+
48
+ # Bacterial diseases
49
+ bacterial:
50
+ - "bacterial blight"
51
+ - "water-soaked lesions"
52
+ - "angular leaf spots"
53
+ - "bacterial canker"
54
+ - "soft rot"
55
+
56
+ # Viral diseases
57
+ viral:
58
+ - "mosaic pattern on leaf"
59
+ - "leaf curl symptoms"
60
+ - "yellowing veins"
61
+ - "ring spots"
62
+ - "stunted growth"
63
+
64
+ # Nutrient deficiency
65
+ nutrient:
66
+ - "chlorosis yellowing"
67
+ - "interveinal chlorosis"
68
+ - "purple leaf coloration"
69
+ - "brown leaf edges"
70
+ - "pale green leaves"
71
+
72
+ # Pest damage
73
+ pest:
74
+ - "insect chewing damage"
75
+ - "leaf mining trails"
76
+ - "aphid colony"
77
+ - "caterpillar damage"
78
+ - "mite damage stippling"
79
+ - "holes in leaves"
80
+
81
+ # Prompt combinations for comprehensive analysis
82
+ analysis_profiles:
83
+ quick_scan:
84
+ prompts:
85
+ - "diseased plant tissue"
86
+ - "healthy green leaf"
87
+ description: "Fast binary healthy/diseased detection"
88
+
89
+ standard:
90
+ prompts:
91
+ - "diseased plant tissue"
92
+ - "leaf spot lesions"
93
+ - "yellowing leaves"
94
+ - "pest damage"
95
+ description: "Standard multi-symptom analysis"
96
+
97
+ comprehensive:
98
+ prompts:
99
+ - "diseased plant tissue"
100
+ - "powdery mildew"
101
+ - "rust pustules"
102
+ - "bacterial blight"
103
+ - "mosaic pattern"
104
+ - "chlorosis"
105
+ - "insect damage"
106
+ - "healthy tissue"
107
+ description: "Full diagnostic scan"
108
+
109
+ pest_focused:
110
+ prompts:
111
+ - "insect chewing damage"
112
+ - "leaf mining trails"
113
+ - "aphid infestation"
114
+ - "caterpillar damage"
115
+ - "mite stippling"
116
+ description: "Pest-specific detection"
117
+
118
+ # Visualization settings
119
+ visualization:
120
+ mask_alpha: 0.5
121
+ colormap: "viridis"
122
+ show_confidence: true
123
+ save_format: "png"
124
+ dpi: 150
125
+
126
+ # Color scheme for severity
127
+ severity_colors:
128
+ healthy: [0, 255, 0] # Green
129
+ mild: [255, 255, 0] # Yellow
130
+ moderate: [255, 165, 0] # Orange
131
+ severe: [255, 0, 0] # Red
132
+
133
+ # Performance optimization
134
+ optimization:
135
+ batch_size: 1
136
+ num_workers: 4
137
+ prefetch_factor: 2
138
+ pin_memory: true
models/rfdetr/checkpoint_best_total.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e24e651f9f144db066cd372fc12f24f8d3a3400d9fddb514638ca513f65d3152
3
+ size 133680431
models/sam2/sam2.1_hiera_small.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6d1aa6f30de5c92224f8172114de081d104bbd23dd9dc5c58996f0cad5dc4d38
3
+ size 184416285
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Requirements for Hugging Face Space
2
+ gradio>=4.0.0
3
+ torch>=2.0.0
4
+ torchvision>=0.15.0
5
+ numpy>=1.24.0
6
+ Pillow>=9.0.0
7
+ opencv-python-headless>=4.8.0
8
+ supervision>=0.16.0
9
+ rfdetr
10
+ sam2
11
+ scipy
12
+ PyYAML
src/__init__.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CropDoctor-Semantic: AI-Powered Plant Disease Diagnosis
3
+ ========================================================
4
+
5
+ A comprehensive pipeline for plant disease diagnosis using:
6
+ - SAM 3 (Segment Anything Model 3) for concept-based segmentation
7
+ - CNN-based severity classification
8
+ - LLM-powered treatment recommendations
9
+
10
+ Main Components:
11
+ - SAM3Segmenter: Zero-shot disease region segmentation
12
+ - SeverityClassifier: CNN for severity assessment
13
+ - TreatmentRecommender: Claude API integration for treatment advice
14
+ - CropDoctorPipeline: End-to-end diagnostic pipeline
15
+
16
+ Quick Start:
17
+ >>> from src.pipeline import CropDoctorPipeline
18
+ >>> pipeline = CropDoctorPipeline()
19
+ >>> result = pipeline.diagnose("path/to/leaf.jpg")
20
+ >>> print(result.disease_name, result.severity_label)
21
+
22
+ For more information, see the README.md or documentation.
23
+ """
24
+
25
+ __version__ = "0.1.0"
26
+ __author__ = "CropDoctor Team"
27
+
28
+ from .sam3_segmentation import (
29
+ SAM3Segmenter,
30
+ MockSAM3Segmenter,
31
+ create_segmenter,
32
+ SegmentationResult
33
+ )
34
+
35
+ from .severity_classifier import (
36
+ SeverityClassifier,
37
+ SeverityClassifierCNN,
38
+ SeverityPrediction,
39
+ SEVERITY_LABELS,
40
+ SEVERITY_DESCRIPTIONS
41
+ )
42
+
43
+ from .treatment_recommender import (
44
+ TreatmentRecommender,
45
+ TreatmentRecommendation,
46
+ DISEASE_DATABASE
47
+ )
48
+
49
+ from .pipeline import (
50
+ CropDoctorPipeline,
51
+ DiagnosticResult,
52
+ quick_diagnose
53
+ )
54
+
55
+ from .visualization import (
56
+ create_diagnostic_visualization,
57
+ create_mask_overlay,
58
+ create_severity_heatmap,
59
+ create_comparison_view,
60
+ create_treatment_card,
61
+ save_visualization
62
+ )
63
+
64
+ __all__ = [
65
+ # Version
66
+ '__version__',
67
+
68
+ # Segmentation
69
+ 'SAM3Segmenter',
70
+ 'MockSAM3Segmenter',
71
+ 'create_segmenter',
72
+ 'SegmentationResult',
73
+
74
+ # Classification
75
+ 'SeverityClassifier',
76
+ 'SeverityClassifierCNN',
77
+ 'SeverityPrediction',
78
+ 'SEVERITY_LABELS',
79
+ 'SEVERITY_DESCRIPTIONS',
80
+
81
+ # Recommendations
82
+ 'TreatmentRecommender',
83
+ 'TreatmentRecommendation',
84
+ 'DISEASE_DATABASE',
85
+
86
+ # Pipeline
87
+ 'CropDoctorPipeline',
88
+ 'DiagnosticResult',
89
+ 'quick_diagnose',
90
+
91
+ # Visualization
92
+ 'create_diagnostic_visualization',
93
+ 'create_mask_overlay',
94
+ 'create_severity_heatmap',
95
+ 'create_comparison_view',
96
+ 'create_treatment_card',
97
+ 'save_visualization',
98
+ ]
src/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (2.06 kB). View file
 
src/__pycache__/gradio_demo.cpython-313.pyc ADDED
Binary file (13 kB). View file
 
src/__pycache__/leaf_segmenter.cpython-313.pyc ADDED
Binary file (13.6 kB). View file
 
src/__pycache__/pipeline.cpython-313.pyc ADDED
Binary file (25.8 kB). View file
 
src/__pycache__/sam3_segmentation.cpython-313.pyc ADDED
Binary file (35.1 kB). View file
 
src/__pycache__/severity_classifier.cpython-313.pyc ADDED
Binary file (24.2 kB). View file
 
src/__pycache__/treatment_recommender.cpython-313.pyc ADDED
Binary file (17 kB). View file
 
src/__pycache__/visualization.cpython-313.pyc ADDED
Binary file (18.3 kB). View file
 
src/gradio_demo.py ADDED
@@ -0,0 +1,391 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Gradio Demo for CropDoctor-Semantic
4
+ ====================================
5
+
6
+ Interactive web interface for plant disease diagnosis.
7
+
8
+ Usage:
9
+ python src/gradio_demo.py
10
+
11
+ Then open http://localhost:7860 in your browser.
12
+ """
13
+
14
+ import gradio as gr
15
+ import numpy as np
16
+ from PIL import Image
17
+ from pathlib import Path
18
+ from datetime import datetime
19
+ import sys
20
+
21
+ # Add src to path
22
+ sys.path.insert(0, str(Path(__file__).parent.parent))
23
+
24
+ from src.pipeline import CropDoctorPipeline, DiagnosticResult
25
+ from src.visualization import create_mask_overlay, SEVERITY_COLORS
26
+ from src.treatment_recommender import TreatmentRecommender
27
+
28
+ # Initialize pipeline with mock for demo
29
+ # Set use_mock_sam3=False when you have SAM 3 installed
30
+ pipeline = None
31
+
32
+ # Output directory for saving results
33
+ OUTPUT_DIR = Path("output")
34
+ OUTPUT_DIR.mkdir(exist_ok=True)
35
+
36
+
37
+ def get_pipeline(use_leaf_segmentation: bool = False):
38
+ """Lazy load pipeline."""
39
+ global pipeline
40
+ if pipeline is None:
41
+ pipeline = CropDoctorPipeline(
42
+ classifier_checkpoint="models/severity_classifier/best.pt",
43
+ use_mock_sam3=False,
44
+ use_rfdetr=True,
45
+ rfdetr_checkpoint="models/rfdetr/checkpoint_best_total.pth",
46
+ use_llm=False,
47
+ use_leaf_segmentation=use_leaf_segmentation,
48
+ sam2_checkpoint="models/sam2/sam2.1_hiera_small.pt"
49
+ )
50
+ # Update leaf segmentation setting
51
+ pipeline.use_leaf_segmentation = use_leaf_segmentation
52
+ return pipeline
53
+
54
+
55
+ def diagnose_image(
56
+ image: np.ndarray,
57
+ plant_species: str,
58
+ analysis_profile: str,
59
+ use_leaf_segmentation: bool = True
60
+ ) -> tuple:
61
+ """
62
+ Main diagnosis function for Gradio.
63
+
64
+ Args:
65
+ image: Input image from Gradio
66
+ plant_species: Selected plant species
67
+ analysis_profile: Analysis profile to use
68
+ use_leaf_segmentation: Whether to segment leaf before detection
69
+
70
+ Returns:
71
+ Tuple of (annotated_image, diagnosis_text, treatment_text, severity_label)
72
+ """
73
+ if image is None:
74
+ return None, "Please upload an image", "", "unknown"
75
+
76
+ # Generate unique timestamp for this request
77
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
78
+ request_dir = OUTPUT_DIR / timestamp
79
+ request_dir.mkdir(exist_ok=True)
80
+
81
+ # Save original image
82
+ original_pil = Image.fromarray(image)
83
+ original_pil.save(request_dir / "original.png")
84
+
85
+ # Get pipeline with leaf segmentation option
86
+ pipe = get_pipeline(use_leaf_segmentation=use_leaf_segmentation)
87
+
88
+ # Get segmented leaf image if enabled
89
+ segmented_image = None
90
+ if use_leaf_segmentation:
91
+ segmented_pil, leaf_mask = pipe.leaf_segmenter.auto_segment_leaf(
92
+ original_pil, return_mask=True
93
+ )
94
+ segmented_image = np.array(segmented_pil)
95
+ # Save segmented leaf (no annotations yet)
96
+ segmented_pil.save(request_dir / "segmented_leaf.png")
97
+
98
+ # Run diagnosis
99
+ result = pipe.diagnose(
100
+ image,
101
+ plant_species=plant_species if plant_species != "Auto-detect" else None,
102
+ analysis_profile=analysis_profile.lower().replace(" ", "_"),
103
+ return_masks=True
104
+ )
105
+
106
+ # Create annotated images
107
+ # 1. Annotated original
108
+ if result.segmentation_masks is not None and len(result.segmentation_masks) > 0:
109
+ annotated_original = create_mask_overlay(
110
+ image,
111
+ result.segmentation_masks,
112
+ alpha=0.5
113
+ )
114
+ Image.fromarray(annotated_original).save(request_dir / "annotated_original.png")
115
+
116
+ # 2. Annotated segmented (if leaf segmentation enabled)
117
+ if segmented_image is not None:
118
+ annotated_segmented = create_mask_overlay(
119
+ segmented_image,
120
+ result.segmentation_masks,
121
+ alpha=0.5
122
+ )
123
+ Image.fromarray(annotated_segmented).save(request_dir / "annotated_segmented.png")
124
+ else:
125
+ annotated_segmented = annotated_original
126
+ else:
127
+ annotated_original = image
128
+ annotated_segmented = segmented_image if segmented_image is not None else image
129
+ Image.fromarray(annotated_original).save(request_dir / "annotated_original.png")
130
+ if segmented_image is not None:
131
+ Image.fromarray(segmented_image).save(request_dir / "annotated_segmented.png")
132
+
133
+ # Format diagnosis text
134
+ diagnosis_text = format_diagnosis(result)
135
+
136
+ # Format treatment text
137
+ treatment_text = format_treatment(result)
138
+
139
+ # Severity for color coding
140
+ severity = result.severity_label.lower()
141
+
142
+ # Add output path info to diagnosis
143
+ diagnosis_text += f"\n\n---\n*Output saved to: `{request_dir}`*"
144
+
145
+ # Return the segmented+annotated image (or original annotated if no segmentation)
146
+ display_image = annotated_segmented if use_leaf_segmentation else annotated_original
147
+
148
+ return display_image, diagnosis_text, treatment_text, severity
149
+
150
+
151
+ def format_diagnosis(result: DiagnosticResult) -> str:
152
+ """Format diagnosis results as markdown."""
153
+
154
+ severity_emoji = {
155
+ "healthy": "🟒",
156
+ "mild": "🟑",
157
+ "moderate": "🟠",
158
+ "severe": "πŸ”΄"
159
+ }
160
+
161
+ emoji = severity_emoji.get(result.severity_label.lower(), "βšͺ")
162
+
163
+ text = f"""
164
+ ## πŸ“‹ Diagnostic Results
165
+
166
+ ### Disease Identification
167
+ - **Disease**: {result.disease_name}
168
+ - **Type**: {result.disease_type.capitalize()}
169
+ - **Confidence**: {result.disease_confidence:.0%}
170
+
171
+ ### Severity Assessment
172
+ - **Level**: {emoji} **{result.severity_label.upper()}** (Level {result.severity_level}/3)
173
+ - **Affected Area**: {result.affected_area_percent:.1f}%
174
+ - **Confidence**: {result.severity_confidence:.0%}
175
+
176
+ ### Detected Symptoms
177
+ """
178
+
179
+ if result.detected_symptoms:
180
+ for symptom in result.detected_symptoms:
181
+ text += f"- {symptom}\n"
182
+ else:
183
+ text += "- No disease symptoms detected\n"
184
+
185
+ text += f"""
186
+ ### Urgency
187
+ **{result.urgency.upper()}** - {"Immediate action required!" if result.urgency == "critical" else "Monitor and treat as recommended."}
188
+ """
189
+
190
+ return text
191
+
192
+
193
+ def format_treatment(result: DiagnosticResult) -> str:
194
+ """Format treatment recommendations as markdown."""
195
+
196
+ text = f"""
197
+ ## 🌿 Treatment Recommendations
198
+
199
+ ### Organic Treatments
200
+ """
201
+
202
+ for i, treatment in enumerate(result.organic_treatments[:5], 1):
203
+ text += f"{i}. {treatment}\n"
204
+
205
+ if result.chemical_treatments:
206
+ text += "\n### Chemical Treatments\n"
207
+ for i, treatment in enumerate(result.chemical_treatments[:3], 1):
208
+ text += f"{i}. {treatment}\n"
209
+
210
+ text += "\n### Preventive Measures\n"
211
+ for i, measure in enumerate(result.preventive_measures[:5], 1):
212
+ text += f"{i}. {measure}\n"
213
+
214
+ text += f"""
215
+ ### Application Timing
216
+ {result.treatment_timing}
217
+
218
+ ---
219
+ *Note: These recommendations are AI-generated. Always consult with a local agricultural expert for critical decisions.*
220
+ """
221
+
222
+ return text
223
+
224
+
225
+ def create_demo():
226
+ """Create Gradio interface."""
227
+
228
+ # Custom CSS
229
+ css = """
230
+ .severity-healthy { background-color: #2ECC71 !important; }
231
+ .severity-mild { background-color: #F1C40F !important; }
232
+ .severity-moderate { background-color: #E67E22 !important; }
233
+ .severity-severe { background-color: #E74C3C !important; }
234
+
235
+ .diagnosis-panel {
236
+ border-left: 4px solid #3498DB;
237
+ padding-left: 1rem;
238
+ }
239
+
240
+ .treatment-panel {
241
+ border-left: 4px solid #27AE60;
242
+ padding-left: 1rem;
243
+ }
244
+ """
245
+
246
+ # Plant species options
247
+ plant_species = [
248
+ "Auto-detect",
249
+ "Apple",
250
+ "Grape",
251
+ "Tomato",
252
+ "Corn/Maize",
253
+ "Potato",
254
+ "Pepper",
255
+ "Strawberry",
256
+ "Cherry",
257
+ "Peach",
258
+ "Citrus",
259
+ "Rice",
260
+ "Wheat",
261
+ "Other"
262
+ ]
263
+
264
+ # Analysis profiles
265
+ profiles = [
266
+ "Quick Scan",
267
+ "Standard",
268
+ "Comprehensive",
269
+ "Pest Focused"
270
+ ]
271
+
272
+ # Example images (you can add real examples here)
273
+ examples = [
274
+ ["examples/healthy_tomato.jpg", "Tomato", "Standard"],
275
+ ["examples/leaf_spot.jpg", "Apple", "Comprehensive"],
276
+ ["examples/powdery_mildew.jpg", "Grape", "Standard"],
277
+ ]
278
+
279
+ with gr.Blocks(css=css, title="CropDoctor-Semantic") as demo:
280
+
281
+ gr.Markdown("""
282
+ # 🌱 CropDoctor-Semantic
283
+ ### AI-Powered Plant Disease Diagnosis with SAM 3
284
+
285
+ Upload an image of a plant leaf to get instant disease diagnosis and treatment recommendations.
286
+
287
+ ---
288
+ """)
289
+
290
+ with gr.Row():
291
+ # Left column - Input
292
+ with gr.Column(scale=1):
293
+ input_image = gr.Image(
294
+ label="πŸ“· Upload Plant Image",
295
+ type="numpy",
296
+ height=400
297
+ )
298
+
299
+ with gr.Row():
300
+ species_dropdown = gr.Dropdown(
301
+ choices=plant_species,
302
+ value="Auto-detect",
303
+ label="🌿 Plant Species"
304
+ )
305
+ profile_dropdown = gr.Dropdown(
306
+ choices=profiles,
307
+ value="Standard",
308
+ label="πŸ” Analysis Profile"
309
+ )
310
+
311
+ leaf_seg_checkbox = gr.Checkbox(
312
+ value=True,
313
+ label="πŸƒ Isoler la feuille (SAM2)",
314
+ info="Segmente la feuille avant la detection pour reduire les faux positifs"
315
+ )
316
+
317
+ diagnose_btn = gr.Button(
318
+ "πŸ”¬ Analyze Plant",
319
+ variant="primary",
320
+ size="lg"
321
+ )
322
+
323
+ # Right column - Output
324
+ with gr.Column(scale=1):
325
+ output_image = gr.Image(
326
+ label="🎯 Detection Results",
327
+ type="numpy",
328
+ height=400
329
+ )
330
+
331
+ severity_label = gr.Textbox(
332
+ label="Severity",
333
+ visible=False
334
+ )
335
+
336
+ with gr.Row():
337
+ # Diagnosis panel
338
+ with gr.Column(scale=1, elem_classes=["diagnosis-panel"]):
339
+ diagnosis_output = gr.Markdown(
340
+ label="Diagnosis",
341
+ value="Upload an image and click 'Analyze Plant' to see results."
342
+ )
343
+
344
+ # Treatment panel
345
+ with gr.Column(scale=1, elem_classes=["treatment-panel"]):
346
+ treatment_output = gr.Markdown(
347
+ label="Treatment",
348
+ value="Treatment recommendations will appear here."
349
+ )
350
+
351
+ # Examples section
352
+ gr.Markdown("### πŸ“š Example Images")
353
+ gr.Markdown("*Click an example to try it out (examples require actual image files)*")
354
+
355
+ # Event handlers
356
+ diagnose_btn.click(
357
+ fn=diagnose_image,
358
+ inputs=[input_image, species_dropdown, profile_dropdown, leaf_seg_checkbox],
359
+ outputs=[output_image, diagnosis_output, treatment_output, severity_label]
360
+ )
361
+
362
+ # Footer
363
+ gr.Markdown("""
364
+ ---
365
+ ### About
366
+
367
+ **CropDoctor-Semantic** uses cutting-edge AI technology:
368
+ - **SAM 3** (Segment Anything Model 3) for concept-based segmentation
369
+ - **Deep Learning** for severity classification
370
+ - **Claude AI** for intelligent treatment recommendations
371
+
372
+ ⚠️ *This is a demo version. For production use, install SAM 3 and configure the Anthropic API.*
373
+
374
+ πŸ“Š **Impact**: Plant diseases cause $220 billion in agricultural losses annually.
375
+ Early detection can significantly reduce crop losses.
376
+
377
+ ---
378
+ *Built with ❀️ for sustainable agriculture*
379
+ """)
380
+
381
+ return demo
382
+
383
+
384
+ if __name__ == "__main__":
385
+ demo = create_demo()
386
+ demo.launch(
387
+ share=False, # Set to True for public link
388
+ server_name="0.0.0.0",
389
+ server_port=7860,
390
+ show_error=True
391
+ )
src/leaf_segmenter.py ADDED
@@ -0,0 +1,388 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Leaf Segmentation using SAM2.
3
+
4
+ This module provides leaf segmentation functionality to isolate leaves
5
+ from backgrounds before disease detection.
6
+ """
7
+
8
+ import numpy as np
9
+ from PIL import Image
10
+ from typing import Optional, Tuple
11
+ import torch
12
+
13
+
14
+ class SAM2LeafSegmenter:
15
+ """
16
+ Segments leaves from images using SAM2 (Segment Anything Model 2).
17
+
18
+ This is used as a preprocessing step to:
19
+ 1. Isolate the leaf from the background
20
+ 2. Create a white background image with just the leaf
21
+ 3. Reduce false positives in disease detection
22
+ """
23
+
24
+ def __init__(
25
+ self,
26
+ checkpoint_path: str = "models/sam2/sam2.1_hiera_small.pt",
27
+ config_file: str = "configs/sam2.1/sam2.1_hiera_s.yaml",
28
+ device: Optional[str] = None
29
+ ):
30
+ """
31
+ Initialize SAM2 leaf segmenter.
32
+
33
+ Args:
34
+ checkpoint_path: Path to SAM2 checkpoint
35
+ config_file: SAM2 config file name
36
+ device: Device to use ('cuda', 'mps', 'cpu'). Auto-detected if None.
37
+ """
38
+ self.checkpoint_path = checkpoint_path
39
+ self.config_file = config_file
40
+
41
+ if device is None:
42
+ if torch.cuda.is_available():
43
+ self.device = 'cuda'
44
+ elif torch.backends.mps.is_available():
45
+ self.device = 'mps'
46
+ else:
47
+ self.device = 'cpu'
48
+ else:
49
+ self.device = device
50
+
51
+ self.model = None
52
+ self.predictor = None
53
+
54
+ def load_model(self):
55
+ """Load SAM2 model."""
56
+ if self.model is not None:
57
+ return
58
+
59
+ from sam2.build_sam import build_sam2
60
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
61
+
62
+ print(f"Loading SAM2 model on {self.device}...")
63
+ self.model = build_sam2(
64
+ config_file=self.config_file,
65
+ ckpt_path=self.checkpoint_path,
66
+ device=self.device
67
+ )
68
+ self.predictor = SAM2ImagePredictor(self.model)
69
+ print("SAM2 model loaded.")
70
+
71
+ def segment_leaf(
72
+ self,
73
+ image: Image.Image,
74
+ point: Optional[Tuple[int, int]] = None,
75
+ return_mask: bool = False
76
+ ) -> Image.Image | Tuple[Image.Image, np.ndarray]:
77
+ """
78
+ Segment the leaf from the image.
79
+
80
+ Args:
81
+ image: PIL Image to segment
82
+ point: (x, y) point to indicate the leaf. If None, uses image center.
83
+ return_mask: If True, also returns the binary mask
84
+
85
+ Returns:
86
+ Image with leaf on white background (and mask if return_mask=True)
87
+ """
88
+ self.load_model()
89
+
90
+ # Convert to numpy array
91
+ image_np = np.array(image.convert('RGB'))
92
+ h, w = image_np.shape[:2]
93
+
94
+ # Use center point if not specified
95
+ if point is None:
96
+ point = (w // 2, h // 2)
97
+
98
+ # Set image for predictor
99
+ self.predictor.set_image(image_np)
100
+
101
+ # Predict mask using point prompt
102
+ point_coords = np.array([[point[0], point[1]]])
103
+ point_labels = np.array([1]) # 1 = foreground
104
+
105
+ masks, scores, _ = self.predictor.predict(
106
+ point_coords=point_coords,
107
+ point_labels=point_labels,
108
+ multimask_output=True
109
+ )
110
+
111
+ # Select best mask (highest score)
112
+ best_idx = np.argmax(scores)
113
+ mask = masks[best_idx].astype(bool)
114
+
115
+ # Create white background image
116
+ result = np.ones_like(image_np) * 255 # White background
117
+ result[mask] = image_np[mask] # Copy leaf pixels
118
+
119
+ result_image = Image.fromarray(result.astype(np.uint8))
120
+
121
+ if return_mask:
122
+ return result_image, mask
123
+ return result_image
124
+
125
+ def segment_leaf_with_bbox(
126
+ self,
127
+ image: Image.Image,
128
+ bbox: Optional[Tuple[int, int, int, int]] = None,
129
+ return_mask: bool = False
130
+ ) -> Image.Image | Tuple[Image.Image, np.ndarray]:
131
+ """
132
+ Segment the leaf using a bounding box prompt.
133
+
134
+ Args:
135
+ image: PIL Image to segment
136
+ bbox: (x1, y1, x2, y2) bounding box. If None, uses full image.
137
+ return_mask: If True, also returns the binary mask
138
+
139
+ Returns:
140
+ Image with leaf on white background (and mask if return_mask=True)
141
+ """
142
+ self.load_model()
143
+
144
+ # Convert to numpy array
145
+ image_np = np.array(image.convert('RGB'))
146
+ h, w = image_np.shape[:2]
147
+
148
+ # Use full image bbox if not specified
149
+ if bbox is None:
150
+ # Use slightly inset bbox to focus on leaf
151
+ margin = min(w, h) // 20
152
+ bbox = (margin, margin, w - margin, h - margin)
153
+
154
+ # Set image for predictor
155
+ self.predictor.set_image(image_np)
156
+
157
+ # Predict mask using box prompt
158
+ box = np.array([bbox])
159
+
160
+ masks, scores, _ = self.predictor.predict(
161
+ box=box,
162
+ multimask_output=True
163
+ )
164
+
165
+ # Select best mask (highest score)
166
+ best_idx = np.argmax(scores)
167
+ mask = masks[best_idx].astype(bool)
168
+
169
+ # Create white background image
170
+ result = np.ones_like(image_np) * 255 # White background
171
+ result[mask] = image_np[mask] # Copy leaf pixels
172
+
173
+ result_image = Image.fromarray(result.astype(np.uint8))
174
+
175
+ if return_mask:
176
+ return result_image, mask
177
+ return result_image
178
+
179
+ def auto_segment_leaf(
180
+ self,
181
+ image: Image.Image,
182
+ return_mask: bool = False
183
+ ) -> Image.Image | Tuple[Image.Image, np.ndarray]:
184
+ """
185
+ Automatically segment the main leaf/plant from the image.
186
+
187
+ Uses multiple strategies to find the best segmentation:
188
+ 1. Center point
189
+ 2. Multiple points in a grid
190
+ 3. Green color detection for better point selection
191
+ 4. Selects the largest coherent mask
192
+
193
+ Args:
194
+ image: PIL Image to segment
195
+ return_mask: If True, also returns the binary mask
196
+
197
+ Returns:
198
+ Image with leaf on white background (and mask if return_mask=True)
199
+ """
200
+ self.load_model()
201
+
202
+ # Convert to numpy array
203
+ image_np = np.array(image.convert('RGB'))
204
+ h, w = image_np.shape[:2]
205
+
206
+ # Set image for predictor
207
+ self.predictor.set_image(image_np)
208
+
209
+ # Try to find a good point on the leaf using green color detection
210
+ # Convert to HSV for better color detection
211
+ from PIL import ImageFilter
212
+ import colorsys
213
+
214
+ # Simple green detection: look for pixels with green hue
215
+ green_mask = self._detect_green_regions(image_np)
216
+
217
+ # Find centroid of green regions, fallback to image center
218
+ if green_mask.sum() > 100: # At least some green pixels
219
+ y_coords, x_coords = np.where(green_mask)
220
+ center_x = int(np.median(x_coords))
221
+ center_y = int(np.median(y_coords))
222
+ else:
223
+ center_x, center_y = w // 2, h // 2
224
+
225
+ # Try multiple points for robustness
226
+ points_to_try = [
227
+ (center_x, center_y), # Green centroid or center
228
+ (w // 2, h // 2), # Image center
229
+ (w // 3, h // 2), # Left third
230
+ (2 * w // 3, h // 2), # Right third
231
+ ]
232
+
233
+ best_mask = None
234
+ best_score = -1
235
+
236
+ for px, py in points_to_try:
237
+ point = np.array([[px, py]])
238
+ label = np.array([1])
239
+
240
+ masks, scores, _ = self.predictor.predict(
241
+ point_coords=point,
242
+ point_labels=label,
243
+ multimask_output=True
244
+ )
245
+
246
+ for mask, score in zip(masks, scores):
247
+ # Ensure mask is boolean for indexing
248
+ mask = mask.astype(bool)
249
+
250
+ # Calculate mask coverage
251
+ coverage = mask.sum() / (h * w)
252
+
253
+ # Prefer masks that cover 5-95% of image (more flexible range)
254
+ if 0.05 < coverage < 0.95:
255
+ # Check if mask contains green (likely a leaf)
256
+ green_in_mask = green_mask[mask].sum() / max(mask.sum(), 1)
257
+
258
+ # Bonus for being closer to 30-70% coverage
259
+ coverage_score = 1 - abs(coverage - 0.5)
260
+
261
+ # Combined score: SAM confidence + coverage + greenness
262
+ combined_score = score * 0.5 + coverage_score * 0.2 + green_in_mask * 0.3
263
+
264
+ if combined_score > best_score:
265
+ best_score = combined_score
266
+ best_mask = mask
267
+
268
+ # Fallback to highest score mask from center point
269
+ if best_mask is None:
270
+ center_point = np.array([[w // 2, h // 2]])
271
+ center_label = np.array([1])
272
+ masks, scores, _ = self.predictor.predict(
273
+ point_coords=center_point,
274
+ point_labels=center_label,
275
+ multimask_output=True
276
+ )
277
+ best_idx = np.argmax(scores)
278
+ best_mask = masks[best_idx]
279
+
280
+ # Ensure mask is boolean
281
+ best_mask = best_mask.astype(bool)
282
+
283
+ # Create white background image
284
+ result = np.ones_like(image_np) * 255 # White background
285
+ result[best_mask] = image_np[best_mask] # Copy leaf pixels
286
+
287
+ result_image = Image.fromarray(result.astype(np.uint8))
288
+
289
+ if return_mask:
290
+ return result_image, best_mask
291
+ return result_image
292
+
293
+ def _detect_green_regions(self, image_np: np.ndarray) -> np.ndarray:
294
+ """Detect green regions in image (likely leaf areas)."""
295
+ # Convert RGB to HSV for better green detection
296
+ r, g, b = image_np[:,:,0], image_np[:,:,1], image_np[:,:,2]
297
+
298
+ # Green typically has: g > r, g > b, and reasonable brightness
299
+ green_mask = (
300
+ (g > r * 0.9) & # Green channel dominant over red
301
+ (g > b * 0.9) & # Green channel dominant over blue
302
+ (g > 40) & # Not too dark
303
+ (g < 250) # Not too bright (white)
304
+ )
305
+
306
+ # Also detect yellow-green (common in leaves)
307
+ yellow_green = (
308
+ (g > 50) &
309
+ (r > 50) &
310
+ (b < r * 0.8) & # Blue much less than red
311
+ (abs(g.astype(int) - r.astype(int)) < 80) # R and G similar
312
+ )
313
+
314
+ return green_mask | yellow_green
315
+
316
+ def refine_boxes_to_masks(
317
+ self,
318
+ image: Image.Image,
319
+ boxes: np.ndarray,
320
+ return_scores: bool = False
321
+ ) -> np.ndarray | Tuple[np.ndarray, np.ndarray]:
322
+ """
323
+ Refine bounding boxes into precise segmentation masks using SAM2.
324
+
325
+ This is used to convert RF-DETR detection boxes into proper
326
+ segmentation masks for disease regions.
327
+
328
+ Args:
329
+ image: PIL Image
330
+ boxes: Array of bounding boxes [N, 4] in xyxy format
331
+ return_scores: If True, also returns confidence scores
332
+
333
+ Returns:
334
+ Array of masks [N, H, W] (and scores if return_scores=True)
335
+ """
336
+ self.load_model()
337
+
338
+ # Convert to numpy array
339
+ image_np = np.array(image.convert('RGB'))
340
+ h, w = image_np.shape[:2]
341
+
342
+ if len(boxes) == 0:
343
+ empty_masks = np.zeros((0, h, w), dtype=bool)
344
+ if return_scores:
345
+ return empty_masks, np.zeros((0,), dtype=np.float32)
346
+ return empty_masks
347
+
348
+ # Set image for predictor
349
+ self.predictor.set_image(image_np)
350
+
351
+ masks_list = []
352
+ scores_list = []
353
+
354
+ for box in boxes:
355
+ # Use box prompt for SAM2
356
+ box_np = np.array([box])
357
+
358
+ masks, scores, _ = self.predictor.predict(
359
+ box=box_np,
360
+ multimask_output=True
361
+ )
362
+
363
+ # Select best mask (highest score)
364
+ best_idx = np.argmax(scores)
365
+ best_mask = masks[best_idx].astype(bool)
366
+ best_score = scores[best_idx]
367
+
368
+ masks_list.append(best_mask)
369
+ scores_list.append(best_score)
370
+
371
+ result_masks = np.stack(masks_list, axis=0) if masks_list else np.zeros((0, h, w), dtype=bool)
372
+ result_scores = np.array(scores_list, dtype=np.float32)
373
+
374
+ if return_scores:
375
+ return result_masks, result_scores
376
+ return result_masks
377
+
378
+
379
+ # Convenience function
380
+ def create_leaf_segmenter(
381
+ checkpoint_path: str = "models/sam2/sam2.1_hiera_small.pt",
382
+ device: Optional[str] = None
383
+ ) -> SAM2LeafSegmenter:
384
+ """Create a SAM2 leaf segmenter instance."""
385
+ return SAM2LeafSegmenter(
386
+ checkpoint_path=checkpoint_path,
387
+ device=device
388
+ )
src/pipeline.py ADDED
@@ -0,0 +1,631 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CropDoctor Pipeline Module
3
+ ==========================
4
+
5
+ This module integrates all components into a unified diagnostic pipeline:
6
+ 1. SAM 3 Segmentation - Detect disease regions
7
+ 2. Severity Classification - Assess severity level
8
+ 3. Treatment Recommendation - Generate actionable advice
9
+
10
+ Provides both single-image and batch processing capabilities.
11
+ """
12
+
13
+ import torch
14
+ import numpy as np
15
+ from PIL import Image
16
+ from pathlib import Path
17
+ from typing import Dict, List, Optional, Union, Tuple
18
+ from dataclasses import dataclass, asdict
19
+ import json
20
+ import csv
21
+ from datetime import datetime
22
+ import logging
23
+
24
+ from .sam3_segmentation import SAM3Segmenter, create_segmenter, SegmentationResult
25
+ from .severity_classifier import SeverityClassifier, SeverityPrediction, SEVERITY_LABELS
26
+ from .treatment_recommender import TreatmentRecommender, TreatmentRecommendation
27
+ from .leaf_segmenter import SAM2LeafSegmenter
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+
32
+ @dataclass
33
+ class DiagnosticResult:
34
+ """Complete diagnostic result for an image."""
35
+ # Image info
36
+ image_path: str
37
+ timestamp: str
38
+
39
+ # Segmentation results
40
+ num_regions_detected: int
41
+ affected_area_percent: float
42
+ detected_symptoms: List[str]
43
+
44
+ # Severity assessment
45
+ severity_level: int
46
+ severity_label: str
47
+ severity_confidence: float
48
+
49
+ # Treatment recommendations
50
+ disease_name: str
51
+ disease_type: str
52
+ disease_confidence: float
53
+ organic_treatments: List[str]
54
+ chemical_treatments: List[str]
55
+ preventive_measures: List[str]
56
+ treatment_timing: str
57
+ urgency: str
58
+
59
+ # Raw data for further analysis
60
+ segmentation_masks: Optional[np.ndarray] = None
61
+ segmentation_scores: Optional[np.ndarray] = None
62
+
63
+
64
+ class CropDoctorPipeline:
65
+ """
66
+ End-to-end pipeline for plant disease diagnosis.
67
+
68
+ Integrates SAM 3 segmentation, severity classification, and
69
+ LLM-based treatment recommendations into a single workflow.
70
+
71
+ Example:
72
+ >>> pipeline = CropDoctorPipeline()
73
+ >>> result = pipeline.diagnose("path/to/leaf.jpg")
74
+ >>> print(f"Disease: {result.disease_name}")
75
+ >>> print(f"Severity: {result.severity_label}")
76
+ >>> print(f"Treatment: {result.organic_treatments[0]}")
77
+ """
78
+
79
+ def __init__(
80
+ self,
81
+ sam3_checkpoint: str = "models/sam3/sam3.pt",
82
+ sam3_config: str = "configs/sam3_config.yaml",
83
+ classifier_checkpoint: Optional[str] = None,
84
+ use_llm: bool = True,
85
+ anthropic_api_key: Optional[str] = None,
86
+ device: Optional[str] = None,
87
+ use_mock_sam3: bool = False,
88
+ use_rfdetr: bool = False,
89
+ rfdetr_checkpoint: str = "models/rfdetr/checkpoint_best_total.pth",
90
+ rfdetr_model_size: str = "medium",
91
+ use_leaf_segmentation: bool = False,
92
+ sam2_checkpoint: str = "models/sam2/sam2.1_hiera_small.pt"
93
+ ):
94
+ """
95
+ Initialize the CropDoctor pipeline.
96
+
97
+ Args:
98
+ sam3_checkpoint: Path to SAM 3 model checkpoint
99
+ sam3_config: Path to SAM 3 configuration
100
+ classifier_checkpoint: Path to severity classifier checkpoint
101
+ use_llm: Whether to use Claude API for recommendations
102
+ anthropic_api_key: Optional API key for Claude
103
+ device: Device to use (auto-detected if None)
104
+ use_mock_sam3: Use mock SAM 3 for testing without model
105
+ use_rfdetr: Use RF-DETR for detection (recommended)
106
+ rfdetr_checkpoint: Path to trained RF-DETR model
107
+ rfdetr_model_size: RF-DETR model size (nano, small, medium, base)
108
+ use_leaf_segmentation: Use SAM2 to segment leaf before detection
109
+ sam2_checkpoint: Path to SAM2 checkpoint for leaf segmentation
110
+ """
111
+ # Set device
112
+ if device is None:
113
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
114
+ else:
115
+ self.device = device
116
+
117
+ logger.info(f"Initializing CropDoctor Pipeline on {self.device}")
118
+
119
+ # Initialize components (lazy loading)
120
+ self._segmenter = None
121
+ self._classifier = None
122
+ self._recommender = None
123
+ self._leaf_segmenter = None
124
+
125
+ # Store config
126
+ self.sam3_checkpoint = sam3_checkpoint
127
+ self.sam3_config = sam3_config
128
+ self.classifier_checkpoint = classifier_checkpoint
129
+ self.use_llm = use_llm
130
+ self.anthropic_api_key = anthropic_api_key
131
+ self.use_mock_sam3 = use_mock_sam3
132
+ self.use_rfdetr = use_rfdetr
133
+ self.rfdetr_checkpoint = rfdetr_checkpoint
134
+ self.rfdetr_model_size = rfdetr_model_size
135
+ self.use_leaf_segmentation = use_leaf_segmentation
136
+ self.sam2_checkpoint = sam2_checkpoint
137
+
138
+ # Default prompts for disease detection
139
+ self.disease_prompts = [
140
+ "diseased plant tissue",
141
+ "leaf with brown spots",
142
+ "powdery mildew",
143
+ "yellowing leaves",
144
+ "pest damage",
145
+ "wilted tissue",
146
+ "healthy green leaf"
147
+ ]
148
+
149
+ @property
150
+ def segmenter(self) -> SAM3Segmenter:
151
+ """Lazy load segmenter (SAM 3, Mock, or RF-DETR)."""
152
+ if self._segmenter is None:
153
+ self._segmenter = create_segmenter(
154
+ self.sam3_checkpoint,
155
+ self.sam3_config,
156
+ use_mock=self.use_mock_sam3,
157
+ use_rfdetr=self.use_rfdetr,
158
+ rfdetr_checkpoint=self.rfdetr_checkpoint,
159
+ rfdetr_model_size=self.rfdetr_model_size
160
+ )
161
+ return self._segmenter
162
+
163
+ @property
164
+ def classifier(self) -> SeverityClassifier:
165
+ """Lazy load severity classifier."""
166
+ if self._classifier is None:
167
+ self._classifier = SeverityClassifier(
168
+ checkpoint_path=self.classifier_checkpoint,
169
+ device=self.device
170
+ )
171
+ return self._classifier
172
+
173
+ @property
174
+ def recommender(self) -> TreatmentRecommender:
175
+ """Lazy load treatment recommender."""
176
+ if self._recommender is None:
177
+ self._recommender = TreatmentRecommender(
178
+ api_key=self.anthropic_api_key,
179
+ use_llm=self.use_llm
180
+ )
181
+ return self._recommender
182
+
183
+ @property
184
+ def leaf_segmenter(self) -> SAM2LeafSegmenter:
185
+ """Lazy load SAM2 leaf segmenter."""
186
+ if self._leaf_segmenter is None:
187
+ self._leaf_segmenter = SAM2LeafSegmenter(
188
+ checkpoint_path=self.sam2_checkpoint,
189
+ device=self.device
190
+ )
191
+ return self._leaf_segmenter
192
+
193
+ def diagnose(
194
+ self,
195
+ image: Union[str, Path, Image.Image, np.ndarray],
196
+ plant_species: Optional[str] = None,
197
+ analysis_profile: str = "standard",
198
+ custom_prompts: Optional[List[str]] = None,
199
+ return_masks: bool = False
200
+ ) -> DiagnosticResult:
201
+ """
202
+ Perform complete diagnosis on an image.
203
+
204
+ Args:
205
+ image: Input image (path, PIL Image, or numpy array)
206
+ plant_species: Optional plant species for context
207
+ analysis_profile: SAM 3 analysis profile
208
+ custom_prompts: Optional custom prompts for segmentation
209
+ return_masks: Whether to include segmentation masks in result
210
+
211
+ Returns:
212
+ DiagnosticResult with complete diagnosis
213
+ """
214
+ timestamp = datetime.now().isoformat()
215
+ image_path = str(image) if isinstance(image, (str, Path)) else "in_memory"
216
+
217
+ logger.info(f"Starting diagnosis for {image_path}")
218
+
219
+ # Load image if needed
220
+ if isinstance(image, (str, Path)):
221
+ pil_image = Image.open(image).convert("RGB")
222
+ elif isinstance(image, np.ndarray):
223
+ pil_image = Image.fromarray(image)
224
+ else:
225
+ pil_image = image
226
+
227
+ # Store original image for visualization
228
+ original_image = pil_image
229
+ leaf_mask = None
230
+
231
+ # Step 0 (optional): Leaf segmentation with SAM2
232
+ if self.use_leaf_segmentation:
233
+ logger.info("Step 0: Leaf segmentation with SAM2")
234
+ pil_image, leaf_mask = self.leaf_segmenter.auto_segment_leaf(
235
+ pil_image,
236
+ return_mask=True
237
+ )
238
+ logger.info("Leaf isolated from background")
239
+
240
+ # Step 1: Disease Detection (SAM3/RF-DETR)
241
+ logger.info("Step 1: Disease detection")
242
+ prompts = custom_prompts or self.disease_prompts
243
+ seg_result = self.segmenter.segment_with_concepts(pil_image, prompts)
244
+
245
+ # Step 1.5: Refine bounding boxes to proper masks using SAM2
246
+ # This converts RF-DETR rectangular boxes to precise segmentation masks
247
+ if self.use_rfdetr and len(seg_result.boxes) > 0:
248
+ logger.info("Step 1.5: Refining detection boxes with SAM2")
249
+ refined_masks = self.leaf_segmenter.refine_boxes_to_masks(
250
+ pil_image,
251
+ seg_result.boxes
252
+ )
253
+ # Replace rectangular masks with refined masks
254
+ seg_result = SegmentationResult(
255
+ masks=refined_masks,
256
+ boxes=seg_result.boxes,
257
+ scores=seg_result.scores,
258
+ prompts=seg_result.prompts,
259
+ prompt_indices=seg_result.prompt_indices
260
+ )
261
+
262
+ # Calculate affected area
263
+ area_stats = self.segmenter.calculate_affected_area(
264
+ seg_result,
265
+ healthy_prompt_idx=prompts.index("healthy green leaf") if "healthy green leaf" in prompts else None
266
+ )
267
+
268
+ # Get detected symptoms (prompts with detections)
269
+ detected_symptoms = []
270
+ for prompt_idx in np.unique(seg_result.prompt_indices):
271
+ if prompt_idx < len(prompts):
272
+ prompt = prompts[prompt_idx]
273
+ if prompt != "healthy green leaf":
274
+ detected_symptoms.append(prompt)
275
+
276
+ # Step 2: Severity Classification
277
+ logger.info("Step 2: Severity classification")
278
+
279
+ # Create combined disease mask for classification
280
+ if len(seg_result.masks) > 0:
281
+ # Combine all disease masks (excluding healthy)
282
+ disease_mask = np.zeros(seg_result.masks[0].shape, dtype=bool)
283
+ for i, mask in enumerate(seg_result.masks):
284
+ prompt_idx = seg_result.prompt_indices[i]
285
+ if prompt_idx < len(prompts) and prompts[prompt_idx] != "healthy green leaf":
286
+ disease_mask |= mask
287
+ else:
288
+ disease_mask = None
289
+
290
+ severity_result = self.classifier.classify(pil_image, mask=disease_mask)
291
+
292
+ # Override severity based on affected area if needed
293
+ if area_stats["total_affected_percent"] < 1:
294
+ severity_result = SeverityPrediction(
295
+ severity_level=0,
296
+ severity_label="healthy",
297
+ confidence=0.9,
298
+ probabilities={"healthy": 0.9, "mild": 0.05, "moderate": 0.03, "severe": 0.02},
299
+ affected_area_percent=area_stats["total_affected_percent"]
300
+ )
301
+ elif area_stats["total_affected_percent"] < 10 and severity_result.severity_level > 1:
302
+ severity_result.severity_level = 1
303
+ severity_result.severity_label = "mild"
304
+
305
+ # Step 3: Treatment Recommendations
306
+ logger.info("Step 3: Generating treatment recommendations")
307
+
308
+ if detected_symptoms and severity_result.severity_level > 0:
309
+ treatment_result = self.recommender.get_recommendation(
310
+ symptoms=detected_symptoms,
311
+ severity=severity_result.severity_label,
312
+ plant_species=plant_species,
313
+ affected_area_percent=area_stats["total_affected_percent"]
314
+ )
315
+ else:
316
+ # Healthy plant - no treatment needed
317
+ treatment_result = TreatmentRecommendation(
318
+ disease_name="No Disease Detected",
319
+ disease_type="healthy",
320
+ confidence=0.9,
321
+ symptoms_matched=[],
322
+ organic_treatments=["Continue regular care"],
323
+ chemical_treatments=[],
324
+ preventive_measures=[
325
+ "Maintain good air circulation",
326
+ "Water at soil level",
327
+ "Monitor regularly for early symptoms"
328
+ ],
329
+ timing="Regular monitoring recommended",
330
+ urgency="low",
331
+ additional_notes="Plant appears healthy. Continue preventive care."
332
+ )
333
+
334
+ # Compile final result
335
+ result = DiagnosticResult(
336
+ image_path=image_path,
337
+ timestamp=timestamp,
338
+ num_regions_detected=len(seg_result.masks),
339
+ affected_area_percent=area_stats["total_affected_percent"],
340
+ detected_symptoms=detected_symptoms,
341
+ severity_level=severity_result.severity_level,
342
+ severity_label=severity_result.severity_label,
343
+ severity_confidence=severity_result.confidence,
344
+ disease_name=treatment_result.disease_name,
345
+ disease_type=treatment_result.disease_type,
346
+ disease_confidence=treatment_result.confidence,
347
+ organic_treatments=treatment_result.organic_treatments,
348
+ chemical_treatments=treatment_result.chemical_treatments,
349
+ preventive_measures=treatment_result.preventive_measures,
350
+ treatment_timing=treatment_result.timing,
351
+ urgency=treatment_result.urgency,
352
+ segmentation_masks=seg_result.masks if return_masks else None,
353
+ segmentation_scores=seg_result.scores if return_masks else None
354
+ )
355
+
356
+ logger.info(f"Diagnosis complete: {result.disease_name} ({result.severity_label})")
357
+
358
+ return result
359
+
360
+ def batch_diagnose(
361
+ self,
362
+ image_folder: Union[str, Path],
363
+ output_dir: Optional[Union[str, Path]] = None,
364
+ plant_species: Optional[str] = None,
365
+ save_visualizations: bool = True,
366
+ file_extensions: List[str] = [".jpg", ".jpeg", ".png"]
367
+ ) -> List[DiagnosticResult]:
368
+ """
369
+ Process multiple images from a folder.
370
+
371
+ Args:
372
+ image_folder: Path to folder containing images
373
+ output_dir: Where to save results (optional)
374
+ plant_species: Plant species for all images
375
+ save_visualizations: Whether to save visualization images
376
+ file_extensions: Image file extensions to process
377
+
378
+ Returns:
379
+ List of DiagnosticResult for each image
380
+ """
381
+ image_folder = Path(image_folder)
382
+
383
+ if output_dir:
384
+ output_dir = Path(output_dir)
385
+ output_dir.mkdir(parents=True, exist_ok=True)
386
+
387
+ # Find all images
388
+ images = []
389
+ for ext in file_extensions:
390
+ images.extend(image_folder.glob(f"*{ext}"))
391
+ images.extend(image_folder.glob(f"*{ext.upper()}"))
392
+
393
+ logger.info(f"Found {len(images)} images to process")
394
+
395
+ results = []
396
+
397
+ for i, img_path in enumerate(images):
398
+ logger.info(f"Processing {i+1}/{len(images)}: {img_path.name}")
399
+
400
+ try:
401
+ result = self.diagnose(
402
+ img_path,
403
+ plant_species=plant_species,
404
+ return_masks=save_visualizations
405
+ )
406
+ results.append(result)
407
+
408
+ # Save visualization if requested
409
+ if save_visualizations and output_dir:
410
+ self._save_visualization(
411
+ img_path,
412
+ result,
413
+ output_dir / f"{img_path.stem}_diagnosis.png"
414
+ )
415
+
416
+ except Exception as e:
417
+ logger.error(f"Error processing {img_path}: {e}")
418
+ continue
419
+
420
+ logger.info(f"Batch processing complete: {len(results)}/{len(images)} successful")
421
+
422
+ return results
423
+
424
+ def _save_visualization(
425
+ self,
426
+ image_path: Path,
427
+ result: DiagnosticResult,
428
+ output_path: Path
429
+ ):
430
+ """Save diagnostic visualization."""
431
+ # Import visualization module
432
+ from .visualization import create_diagnostic_visualization
433
+
434
+ image = Image.open(image_path)
435
+
436
+ fig = create_diagnostic_visualization(
437
+ image,
438
+ result.segmentation_masks,
439
+ result.severity_label,
440
+ result.disease_name,
441
+ result.affected_area_percent
442
+ )
443
+
444
+ fig.savefig(output_path, dpi=150, bbox_inches='tight')
445
+ import matplotlib.pyplot as plt
446
+ plt.close(fig)
447
+
448
+ def export_report(
449
+ self,
450
+ results: List[DiagnosticResult],
451
+ output_path: Union[str, Path],
452
+ format: str = "csv"
453
+ ):
454
+ """
455
+ Export results to file.
456
+
457
+ Args:
458
+ results: List of diagnostic results
459
+ output_path: Output file path
460
+ format: Output format ("csv", "json")
461
+ """
462
+ output_path = Path(output_path)
463
+
464
+ if format == "csv":
465
+ self._export_csv(results, output_path)
466
+ elif format == "json":
467
+ self._export_json(results, output_path)
468
+ else:
469
+ raise ValueError(f"Unknown format: {format}")
470
+
471
+ logger.info(f"Report exported to {output_path}")
472
+
473
+ def _export_csv(self, results: List[DiagnosticResult], output_path: Path):
474
+ """Export to CSV."""
475
+ with open(output_path, 'w', newline='') as f:
476
+ if not results:
477
+ return
478
+
479
+ # Get fields (excluding numpy arrays)
480
+ fields = [
481
+ 'image_path', 'timestamp', 'num_regions_detected',
482
+ 'affected_area_percent', 'detected_symptoms',
483
+ 'severity_level', 'severity_label', 'severity_confidence',
484
+ 'disease_name', 'disease_type', 'disease_confidence',
485
+ 'organic_treatments', 'urgency'
486
+ ]
487
+
488
+ writer = csv.DictWriter(f, fieldnames=fields)
489
+ writer.writeheader()
490
+
491
+ for result in results:
492
+ row = {
493
+ 'image_path': result.image_path,
494
+ 'timestamp': result.timestamp,
495
+ 'num_regions_detected': result.num_regions_detected,
496
+ 'affected_area_percent': f"{result.affected_area_percent:.2f}",
497
+ 'detected_symptoms': '; '.join(result.detected_symptoms),
498
+ 'severity_level': result.severity_level,
499
+ 'severity_label': result.severity_label,
500
+ 'severity_confidence': f"{result.severity_confidence:.3f}",
501
+ 'disease_name': result.disease_name,
502
+ 'disease_type': result.disease_type,
503
+ 'disease_confidence': f"{result.disease_confidence:.3f}",
504
+ 'organic_treatments': '; '.join(result.organic_treatments[:3]),
505
+ 'urgency': result.urgency
506
+ }
507
+ writer.writerow(row)
508
+
509
+ def _export_json(self, results: List[DiagnosticResult], output_path: Path):
510
+ """Export to JSON."""
511
+ data = []
512
+ for result in results:
513
+ d = asdict(result)
514
+ # Remove numpy arrays
515
+ d.pop('segmentation_masks', None)
516
+ d.pop('segmentation_scores', None)
517
+ data.append(d)
518
+
519
+ with open(output_path, 'w') as f:
520
+ json.dump(data, f, indent=2, default=str)
521
+
522
+ def generate_summary_report(
523
+ self,
524
+ results: List[DiagnosticResult]
525
+ ) -> str:
526
+ """
527
+ Generate a summary report for batch results.
528
+
529
+ Args:
530
+ results: List of diagnostic results
531
+
532
+ Returns:
533
+ Formatted summary report string
534
+ """
535
+ if not results:
536
+ return "No results to summarize."
537
+
538
+ # Calculate statistics
539
+ total = len(results)
540
+ healthy = sum(1 for r in results if r.severity_level == 0)
541
+ mild = sum(1 for r in results if r.severity_level == 1)
542
+ moderate = sum(1 for r in results if r.severity_level == 2)
543
+ severe = sum(1 for r in results if r.severity_level == 3)
544
+
545
+ # Disease frequency
546
+ disease_counts = {}
547
+ for r in results:
548
+ disease_counts[r.disease_name] = disease_counts.get(r.disease_name, 0) + 1
549
+
550
+ # Average affected area
551
+ avg_affected = np.mean([r.affected_area_percent for r in results])
552
+
553
+ report = f"""
554
+ ╔══════════════════════════════════════════════════════════════╗
555
+ β•‘ BATCH DIAGNOSIS SUMMARY β•‘
556
+ β•šβ•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•
557
+
558
+ πŸ“Š OVERALL STATISTICS
559
+ Total Images Analyzed: {total}
560
+
561
+ Severity Distribution:
562
+ β”œβ”€β”€ 🟒 Healthy: {healthy} ({healthy/total*100:.1f}%)
563
+ β”œβ”€β”€ 🟑 Mild: {mild} ({mild/total*100:.1f}%)
564
+ β”œβ”€β”€ 🟠 Moderate: {moderate} ({moderate/total*100:.1f}%)
565
+ └── πŸ”΄ Severe: {severe} ({severe/total*100:.1f}%)
566
+
567
+ Average Affected Area: {avg_affected:.1f}%
568
+
569
+ 🦠 DISEASE FREQUENCY
570
+ """
571
+ for disease, count in sorted(disease_counts.items(), key=lambda x: -x[1]):
572
+ report += f" β€’ {disease}: {count} ({count/total*100:.1f}%)\n"
573
+
574
+ # Urgent cases
575
+ urgent = [r for r in results if r.urgency in ['high', 'critical']]
576
+ if urgent:
577
+ report += f"""
578
+ ⚠️ URGENT ATTENTION REQUIRED
579
+ {len(urgent)} images require immediate attention:
580
+ """
581
+ for r in urgent[:5]: # Show top 5
582
+ report += f" β€’ {Path(r.image_path).name}: {r.disease_name} ({r.urgency})\n"
583
+
584
+ report += """
585
+ ═══════════════════════════════════════════════════════════════
586
+ """
587
+ return report
588
+
589
+
590
+ def quick_diagnose(
591
+ image_path: str,
592
+ use_mock: bool = True
593
+ ) -> DiagnosticResult:
594
+ """
595
+ Quick diagnosis function for simple use cases.
596
+
597
+ Args:
598
+ image_path: Path to image
599
+ use_mock: Use mock models (for testing without SAM 3)
600
+
601
+ Returns:
602
+ DiagnosticResult
603
+ """
604
+ pipeline = CropDoctorPipeline(
605
+ use_mock_sam3=use_mock,
606
+ use_llm=False
607
+ )
608
+ return pipeline.diagnose(image_path)
609
+
610
+
611
+ if __name__ == "__main__":
612
+ # Test the pipeline with mock
613
+ print("Testing CropDoctor Pipeline...")
614
+
615
+ # Create test image
616
+ test_img = Image.new("RGB", (640, 480), color=(139, 69, 19))
617
+ test_img.save("/tmp/test_leaf.jpg")
618
+
619
+ # Run pipeline
620
+ pipeline = CropDoctorPipeline(use_mock_sam3=True, use_llm=False)
621
+ result = pipeline.diagnose("/tmp/test_leaf.jpg")
622
+
623
+ print(f"\nπŸ“‹ Diagnosis Results:")
624
+ print(f" Disease: {result.disease_name}")
625
+ print(f" Type: {result.disease_type}")
626
+ print(f" Severity: {result.severity_label} (Level {result.severity_level})")
627
+ print(f" Affected Area: {result.affected_area_percent:.1f}%")
628
+ print(f" Urgency: {result.urgency}")
629
+ print(f"\n🌿 Recommended Treatments:")
630
+ for t in result.organic_treatments[:3]:
631
+ print(f" β€’ {t}")
src/sam3_segmentation.py ADDED
@@ -0,0 +1,864 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SAM 3 Segmentation Module for CropDoctor-Semantic
3
+ ==================================================
4
+
5
+ This module provides the core segmentation functionality using Meta's SAM 3
6
+ for concept-based plant disease detection.
7
+
8
+ SAM 3 enables zero-shot segmentation using natural language prompts,
9
+ allowing detection of disease symptoms without task-specific training.
10
+ """
11
+
12
+ import torch
13
+ import numpy as np
14
+ from PIL import Image
15
+ from pathlib import Path
16
+ from typing import List, Dict, Tuple, Optional, Union
17
+ from dataclasses import dataclass
18
+ import yaml
19
+ import logging
20
+
21
+ # Configure logging
22
+ logging.basicConfig(level=logging.INFO)
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ @dataclass
27
+ class SegmentationResult:
28
+ """Container for segmentation results."""
29
+ masks: np.ndarray # Shape: (N, H, W) boolean masks
30
+ boxes: np.ndarray # Shape: (N, 4) bounding boxes [x1, y1, x2, y2]
31
+ scores: np.ndarray # Shape: (N,) confidence scores
32
+ prompts: List[str] # Prompts used for each detection
33
+ prompt_indices: np.ndarray # Which prompt each mask corresponds to
34
+
35
+
36
+ class SAM3Segmenter:
37
+ """
38
+ SAM 3 based segmentation for plant disease detection.
39
+
40
+ Uses text prompts to detect and segment disease symptoms in plant images.
41
+ SAM 3's Promptable Concept Segmentation (PCS) enables open-vocabulary
42
+ detection without fine-tuning.
43
+
44
+ Example:
45
+ >>> segmenter = SAM3Segmenter("models/sam3/sam3.pt")
46
+ >>> result = segmenter.segment_with_concepts(
47
+ ... "leaf_image.jpg",
48
+ ... ["leaf with brown spots", "healthy leaf"]
49
+ ... )
50
+ >>> print(f"Found {len(result.masks)} regions")
51
+ """
52
+
53
+ def __init__(
54
+ self,
55
+ checkpoint_path: str = "models/sam3/sam3.pt",
56
+ config_path: str = "configs/sam3_config.yaml",
57
+ device: Optional[str] = None
58
+ ):
59
+ """
60
+ Initialize SAM 3 segmenter.
61
+
62
+ Args:
63
+ checkpoint_path: Path to SAM 3 checkpoint
64
+ config_path: Path to configuration file
65
+ device: Device to use (cuda, cpu, mps). Auto-detected if None.
66
+ """
67
+ self.checkpoint_path = Path(checkpoint_path)
68
+ self.config = self._load_config(config_path)
69
+
70
+ # Set device
71
+ if device is None:
72
+ if torch.cuda.is_available():
73
+ self.device = "cuda"
74
+ elif torch.backends.mps.is_available():
75
+ self.device = "mps"
76
+ else:
77
+ self.device = "cpu"
78
+ else:
79
+ self.device = device
80
+
81
+ logger.info(f"Using device: {self.device}")
82
+
83
+ # Model will be loaded lazily
84
+ self.model = None
85
+ self.processor = None
86
+
87
+ def _load_config(self, config_path: str) -> dict:
88
+ """Load configuration from YAML file."""
89
+ config_path = Path(config_path)
90
+ if config_path.exists():
91
+ with open(config_path, 'r') as f:
92
+ return yaml.safe_load(f)
93
+ else:
94
+ logger.warning(f"Config not found at {config_path}, using defaults")
95
+ return self._default_config()
96
+
97
+ def _default_config(self) -> dict:
98
+ """Return default configuration."""
99
+ return {
100
+ "inference": {
101
+ "confidence_threshold": 0.25,
102
+ "presence_threshold": 0.5,
103
+ "max_objects_per_prompt": 50,
104
+ "min_mask_area": 100
105
+ },
106
+ "visualization": {
107
+ "mask_alpha": 0.5,
108
+ "show_confidence": True
109
+ }
110
+ }
111
+
112
+ def load_model(self):
113
+ """Load SAM 3 model and processor."""
114
+ if self.model is not None:
115
+ return
116
+
117
+ logger.info("Loading SAM 3 model...")
118
+
119
+ try:
120
+ # Import SAM 3 modules
121
+ from sam3.model_builder import build_sam3_image_model
122
+ from sam3.model.sam3_image_processor import Sam3Processor
123
+
124
+ # Build model
125
+ self.model = build_sam3_image_model(checkpoint=str(self.checkpoint_path))
126
+ self.model.to(self.device)
127
+
128
+ if self.config.get("model", {}).get("half_precision", True) and self.device == "cuda":
129
+ self.model = self.model.half()
130
+
131
+ self.model.eval()
132
+
133
+ # Create processor
134
+ self.processor = Sam3Processor(self.model)
135
+
136
+ logger.info("SAM 3 model loaded successfully")
137
+
138
+ except ImportError:
139
+ logger.error("SAM 3 not installed. Please install from: https://github.com/facebookresearch/sam3")
140
+ raise
141
+ except FileNotFoundError:
142
+ logger.error(f"Checkpoint not found at {self.checkpoint_path}")
143
+ raise
144
+
145
+ def segment_with_concepts(
146
+ self,
147
+ image: Union[str, Path, Image.Image, np.ndarray],
148
+ text_prompts: List[str],
149
+ confidence_threshold: Optional[float] = None
150
+ ) -> SegmentationResult:
151
+ """
152
+ Segment image using text prompts.
153
+
154
+ Args:
155
+ image: Input image (path, PIL Image, or numpy array)
156
+ text_prompts: List of text prompts describing concepts to detect
157
+ confidence_threshold: Override default confidence threshold
158
+
159
+ Returns:
160
+ SegmentationResult containing masks, boxes, scores, and prompt info
161
+ """
162
+ # Ensure model is loaded
163
+ self.load_model()
164
+
165
+ # Load image
166
+ if isinstance(image, (str, Path)):
167
+ image = Image.open(image).convert("RGB")
168
+ elif isinstance(image, np.ndarray):
169
+ image = Image.fromarray(image)
170
+
171
+ # Get threshold
172
+ threshold = confidence_threshold or self.config["inference"]["confidence_threshold"]
173
+
174
+ # Set image in processor
175
+ inference_state = self.processor.set_image(image)
176
+
177
+ # Collect results from all prompts
178
+ all_masks = []
179
+ all_boxes = []
180
+ all_scores = []
181
+ all_prompt_indices = []
182
+
183
+ for prompt_idx, prompt in enumerate(text_prompts):
184
+ logger.debug(f"Processing prompt: {prompt}")
185
+
186
+ # Get segmentation for this prompt
187
+ output = self.processor.set_text_prompt(
188
+ state=inference_state,
189
+ prompt=prompt
190
+ )
191
+
192
+ masks = output["masks"]
193
+ boxes = output["boxes"]
194
+ scores = output["scores"]
195
+
196
+ if masks is not None and len(masks) > 0:
197
+ # Convert to numpy
198
+ masks_np = masks.cpu().numpy() if torch.is_tensor(masks) else masks
199
+ boxes_np = boxes.cpu().numpy() if torch.is_tensor(boxes) else boxes
200
+ scores_np = scores.cpu().numpy() if torch.is_tensor(scores) else scores
201
+
202
+ # Filter by confidence
203
+ mask = scores_np >= threshold
204
+
205
+ if mask.any():
206
+ all_masks.append(masks_np[mask])
207
+ all_boxes.append(boxes_np[mask])
208
+ all_scores.append(scores_np[mask])
209
+ all_prompt_indices.append(
210
+ np.full(mask.sum(), prompt_idx, dtype=np.int32)
211
+ )
212
+
213
+ # Combine results
214
+ if all_masks:
215
+ combined_masks = np.concatenate(all_masks, axis=0)
216
+ combined_boxes = np.concatenate(all_boxes, axis=0)
217
+ combined_scores = np.concatenate(all_scores, axis=0)
218
+ combined_indices = np.concatenate(all_prompt_indices, axis=0)
219
+ else:
220
+ # Return empty results
221
+ h, w = np.array(image).shape[:2]
222
+ combined_masks = np.zeros((0, h, w), dtype=bool)
223
+ combined_boxes = np.zeros((0, 4), dtype=np.float32)
224
+ combined_scores = np.zeros((0,), dtype=np.float32)
225
+ combined_indices = np.zeros((0,), dtype=np.int32)
226
+
227
+ return SegmentationResult(
228
+ masks=combined_masks,
229
+ boxes=combined_boxes,
230
+ scores=combined_scores,
231
+ prompts=text_prompts,
232
+ prompt_indices=combined_indices
233
+ )
234
+
235
+ def segment_disease_regions(
236
+ self,
237
+ image: Union[str, Path, Image.Image, np.ndarray],
238
+ profile: str = "standard"
239
+ ) -> SegmentationResult:
240
+ """
241
+ Segment disease regions using predefined prompt profiles.
242
+
243
+ Args:
244
+ image: Input image
245
+ profile: Analysis profile ("quick_scan", "standard", "comprehensive", "pest_focused")
246
+
247
+ Returns:
248
+ SegmentationResult for the specified analysis profile
249
+ """
250
+ profiles = self.config.get("analysis_profiles", {})
251
+
252
+ if profile not in profiles:
253
+ available = list(profiles.keys())
254
+ raise ValueError(f"Profile '{profile}' not found. Available: {available}")
255
+
256
+ prompts = profiles[profile]["prompts"]
257
+ logger.info(f"Using profile '{profile}' with {len(prompts)} prompts")
258
+
259
+ return self.segment_with_concepts(image, prompts)
260
+
261
+ def calculate_affected_area(
262
+ self,
263
+ result: SegmentationResult,
264
+ healthy_prompt_idx: Optional[int] = None
265
+ ) -> Dict[str, float]:
266
+ """
267
+ Calculate the percentage of affected area.
268
+
269
+ Args:
270
+ result: Segmentation result
271
+ healthy_prompt_idx: Index of the "healthy" prompt for comparison
272
+
273
+ Returns:
274
+ Dictionary with area statistics
275
+ """
276
+ if len(result.masks) == 0:
277
+ return {"total_affected_percent": 0.0, "per_symptom": {}}
278
+
279
+ # Total image area
280
+ h, w = result.masks[0].shape
281
+ total_area = h * w
282
+
283
+ # Calculate areas per prompt
284
+ per_symptom = {}
285
+ total_diseased_area = 0
286
+ healthy_area = 0
287
+
288
+ for prompt_idx, prompt in enumerate(result.prompts):
289
+ mask_indices = result.prompt_indices == prompt_idx
290
+ if mask_indices.any():
291
+ combined_mask = result.masks[mask_indices].any(axis=0)
292
+ area = combined_mask.sum()
293
+ percent = (area / total_area) * 100
294
+ per_symptom[prompt] = percent
295
+
296
+ if healthy_prompt_idx is not None and prompt_idx == healthy_prompt_idx:
297
+ healthy_area = area
298
+ else:
299
+ total_diseased_area += area
300
+
301
+ # Calculate total affected (excluding overlaps approximation)
302
+ all_diseased_mask = np.zeros((h, w), dtype=bool)
303
+ for prompt_idx, prompt in enumerate(result.prompts):
304
+ if healthy_prompt_idx is None or prompt_idx != healthy_prompt_idx:
305
+ mask_indices = result.prompt_indices == prompt_idx
306
+ if mask_indices.any():
307
+ all_diseased_mask |= result.masks[mask_indices].any(axis=0)
308
+
309
+ affected_percent = (all_diseased_mask.sum() / total_area) * 100
310
+
311
+ return {
312
+ "total_affected_percent": affected_percent,
313
+ "per_symptom": per_symptom,
314
+ "healthy_percent": (healthy_area / total_area) * 100 if healthy_prompt_idx else None
315
+ }
316
+
317
+ def get_disease_prompts(self, category: str = "all") -> List[str]:
318
+ """
319
+ Get predefined disease detection prompts.
320
+
321
+ Args:
322
+ category: Prompt category ("general", "fungal", "bacterial",
323
+ "viral", "nutrient", "pest", or "all")
324
+
325
+ Returns:
326
+ List of prompts for the specified category
327
+ """
328
+ prompts_config = self.config.get("prompts", {})
329
+
330
+ if category == "all":
331
+ all_prompts = []
332
+ for cat_prompts in prompts_config.values():
333
+ all_prompts.extend(cat_prompts)
334
+ return all_prompts
335
+ elif category in prompts_config:
336
+ return prompts_config[category]
337
+ else:
338
+ available = list(prompts_config.keys()) + ["all"]
339
+ raise ValueError(f"Category '{category}' not found. Available: {available}")
340
+
341
+
342
+ class MockSAM3Segmenter(SAM3Segmenter):
343
+ """
344
+ Color-based segmentation for plant disease detection.
345
+
346
+ Analyzes actual image colors to detect disease symptoms:
347
+ - Green regions = healthy tissue
348
+ - Brown/yellow/spotted regions = potential disease
349
+
350
+ Uses scipy.ndimage for blob detection on non-green regions.
351
+ """
352
+
353
+ def load_model(self):
354
+ """Skip model loading for color-based analysis."""
355
+ logger.info("Using MockSAM3Segmenter (color-based analysis)")
356
+ self.model = "color_analysis"
357
+ self.processor = "color_analysis"
358
+
359
+ def _compute_hsv(self, img_array: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
360
+ """Convert RGB image to HSV channels."""
361
+ r, g, b = img_array[:,:,0], img_array[:,:,1], img_array[:,:,2]
362
+
363
+ rgb_max = np.maximum(np.maximum(r, g), b)
364
+ rgb_min = np.minimum(np.minimum(r, g), b)
365
+ delta = (rgb_max - rgb_min).astype(np.float32) + 1e-10
366
+
367
+ # Value
368
+ v = rgb_max
369
+
370
+ # Saturation
371
+ s = np.where(rgb_max > 0, (delta / (rgb_max.astype(np.float32) + 1e-10)) * 255, 0).astype(np.uint8)
372
+
373
+ # Hue
374
+ h_channel = np.zeros_like(r, dtype=np.float32)
375
+
376
+ mask_r = (rgb_max == r)
377
+ h_channel[mask_r] = 60 * (((g[mask_r].astype(np.float32) - b[mask_r]) / delta[mask_r]) % 6)
378
+
379
+ mask_g = (rgb_max == g) & ~mask_r
380
+ h_channel[mask_g] = 60 * (((b[mask_g].astype(np.float32) - r[mask_g]) / delta[mask_g]) + 2)
381
+
382
+ mask_b = (rgb_max == b) & ~mask_r & ~mask_g
383
+ h_channel[mask_b] = 60 * (((r[mask_b].astype(np.float32) - g[mask_b]) / delta[mask_b]) + 4)
384
+
385
+ h_channel = (h_channel / 2).astype(np.uint8) # 0-180 range
386
+
387
+ return h_channel, s, v
388
+
389
+ def _segment_leaf(self, img_array: np.ndarray, h: np.ndarray, s: np.ndarray, v: np.ndarray) -> np.ndarray:
390
+ """
391
+ Segment the leaf/plant tissue from the background.
392
+
393
+ Uses color analysis to find plant material:
394
+ - High saturation (plants are colorful, backgrounds are often gray/neutral)
395
+ - Green to yellow-brown hue range (plant tissue colors)
396
+ - Reasonable brightness
397
+
398
+ Returns the largest connected region as the leaf mask.
399
+ """
400
+ from scipy import ndimage
401
+
402
+ img_h, img_w = img_array.shape[:2]
403
+
404
+ # Plant tissue typically has:
405
+ # 1. Good saturation (colorful, not gray)
406
+ # 2. Hue in plant range: green (35-85) OR yellow/brown diseased (10-45)
407
+ # 3. Reasonable brightness
408
+
409
+ # Broad plant color range (green to brown/yellow)
410
+ plant_hue_mask = (
411
+ ((h >= 15) & (h <= 90)) | # Green to yellow-green
412
+ ((h >= 5) & (h <= 30)) # Brown/orange (diseased tissue)
413
+ )
414
+
415
+ # Plant tissue has good saturation and brightness
416
+ plant_saturation_mask = (s >= 25) # Saturated (not gray)
417
+ plant_brightness_mask = (v >= 30) & (v <= 250) # Not too dark, not blown out
418
+
419
+ # Combine criteria
420
+ potential_leaf = plant_hue_mask & plant_saturation_mask & plant_brightness_mask
421
+
422
+ # Also include high saturation areas regardless of hue (catches more plant tissue)
423
+ high_saturation = (s >= 50) & plant_brightness_mask
424
+ potential_leaf = potential_leaf | high_saturation
425
+
426
+ # Clean up with morphological operations
427
+ potential_leaf = ndimage.binary_closing(potential_leaf, iterations=3)
428
+ potential_leaf = ndimage.binary_opening(potential_leaf, iterations=2)
429
+ potential_leaf = ndimage.binary_fill_holes(potential_leaf)
430
+
431
+ # Find the largest connected component (main leaf)
432
+ labeled, num_features = ndimage.label(potential_leaf)
433
+
434
+ if num_features == 0:
435
+ # No leaf found - return full image as fallback
436
+ logger.warning("No leaf detected, using full image")
437
+ return np.ones((img_h, img_w), dtype=bool)
438
+
439
+ # Find largest component
440
+ component_sizes = ndimage.sum(potential_leaf, labeled, range(1, num_features + 1))
441
+ largest_idx = np.argmax(component_sizes) + 1
442
+ leaf_mask = (labeled == largest_idx)
443
+
444
+ # Leaf should cover at least 10% of image to be valid
445
+ leaf_coverage = leaf_mask.sum() / (img_h * img_w)
446
+ if leaf_coverage < 0.10:
447
+ logger.warning(f"Leaf too small ({leaf_coverage:.1%}), using full image")
448
+ return np.ones((img_h, img_w), dtype=bool)
449
+
450
+ logger.debug(f"Leaf segmented: {leaf_coverage:.1%} of image")
451
+ return leaf_mask
452
+
453
+ def _detect_disease_regions(self, image: Image.Image) -> Tuple[np.ndarray, List[Dict]]:
454
+ """
455
+ Detect disease regions based on color analysis.
456
+
457
+ First segments the leaf from background, then analyzes only
458
+ the leaf area for disease symptoms.
459
+
460
+ Returns:
461
+ Tuple of (binary mask of all abnormal regions, list of blob info dicts)
462
+ """
463
+ from scipy import ndimage
464
+
465
+ img_array = np.array(image)
466
+ img_h, img_w = img_array.shape[:2]
467
+
468
+ # Compute HSV
469
+ h_channel, s, v = self._compute_hsv(img_array)
470
+
471
+ # Step 1: Segment the leaf from background
472
+ leaf_mask = self._segment_leaf(img_array, h_channel, s, v)
473
+ leaf_area = leaf_mask.sum()
474
+
475
+ if leaf_area == 0:
476
+ return np.zeros((img_h, img_w), dtype=bool), []
477
+
478
+ # Step 2: Within the leaf, find healthy green regions
479
+ green_mask = (
480
+ (h_channel >= 35) & (h_channel <= 85) & # Green hue
481
+ (s >= 30) & # Saturated
482
+ (v >= 30) & # Not too dark
483
+ leaf_mask # Only within leaf
484
+ )
485
+
486
+ green_area = green_mask.sum()
487
+ green_ratio = green_area / leaf_area
488
+
489
+ logger.debug(f"Within leaf - Green: {green_ratio:.1%}, Leaf area: {leaf_area}px")
490
+
491
+ # Step 3: Define disease colors (only within leaf!)
492
+ # Brown spots: low hue, moderate saturation
493
+ brown_mask = (
494
+ (h_channel >= 5) & (h_channel <= 25) &
495
+ (s >= 30) &
496
+ (v >= 40) & (v <= 200) &
497
+ leaf_mask
498
+ )
499
+
500
+ # Yellow/chlorosis: yellow hue, high saturation
501
+ yellow_mask = (
502
+ (h_channel >= 20) & (h_channel <= 40) &
503
+ (s >= 40) &
504
+ (v >= 80) &
505
+ leaf_mask
506
+ )
507
+
508
+ # Necrotic dark spots (within leaf only)
509
+ dark_spots = (
510
+ (v <= 60) &
511
+ (s >= 15) & # Some color, not pure black
512
+ leaf_mask
513
+ )
514
+
515
+ # White spots (powdery mildew) - within leaf
516
+ white_spots = (
517
+ (v >= 200) &
518
+ (s <= 40) &
519
+ leaf_mask
520
+ )
521
+
522
+ # Combine abnormal regions
523
+ abnormal_mask = (brown_mask | yellow_mask | dark_spots | white_spots)
524
+ abnormal_area = abnormal_mask.sum()
525
+
526
+ logger.debug(f"Abnormal pixels within leaf: {abnormal_area} ({abnormal_area/leaf_area:.1%} of leaf)")
527
+
528
+ # If mostly green (>80% of leaf is green), consider healthy
529
+ if green_ratio > 0.80 and abnormal_area < leaf_area * 0.05:
530
+ logger.info(f"Leaf appears healthy ({green_ratio:.0%} green)")
531
+ return np.zeros((img_h, img_w), dtype=bool), []
532
+
533
+ # If very little abnormal tissue, also healthy
534
+ if abnormal_area < leaf_area * 0.02:
535
+ logger.info("Minimal abnormal tissue detected - healthy")
536
+ return np.zeros((img_h, img_w), dtype=bool), []
537
+
538
+ # Clean up the abnormal mask
539
+ abnormal_mask = ndimage.binary_opening(abnormal_mask, iterations=1)
540
+ abnormal_mask = ndimage.binary_closing(abnormal_mask, iterations=2)
541
+
542
+ # Label connected components
543
+ labeled_array, num_features = ndimage.label(abnormal_mask)
544
+
545
+ # Filter blobs by size (relative to leaf, not image)
546
+ min_blob_area = max(50, leaf_area * 0.005) # At least 0.5% of leaf
547
+ max_blob_area = leaf_area * 0.6 # At most 60% of leaf
548
+
549
+ blobs = []
550
+ for label_idx in range(1, num_features + 1):
551
+ blob_mask = (labeled_array == label_idx)
552
+ blob_area = blob_mask.sum()
553
+
554
+ if min_blob_area <= blob_area <= max_blob_area:
555
+ # Get bounding box
556
+ rows = np.any(blob_mask, axis=1)
557
+ cols = np.any(blob_mask, axis=0)
558
+ y_min, y_max = np.where(rows)[0][[0, -1]]
559
+ x_min, x_max = np.where(cols)[0][[0, -1]]
560
+
561
+ # Calculate confidence based on color
562
+ blob_region = img_array[blob_mask]
563
+ avg_color = blob_region.mean(axis=0)
564
+
565
+ r_ratio = avg_color[0] / 255
566
+ g_ratio = avg_color[1] / 255
567
+ b_ratio = avg_color[2] / 255
568
+
569
+ # Score: more brown/yellow = higher confidence
570
+ color_score = r_ratio - 0.5 * g_ratio + 0.3 * (1 - b_ratio)
571
+ color_score = np.clip(color_score, 0, 1)
572
+
573
+ # Area score relative to leaf
574
+ area_ratio = blob_area / leaf_area
575
+ area_score = min(1.0, area_ratio * 10)
576
+
577
+ confidence = 0.4 + 0.4 * color_score + 0.2 * area_score
578
+ confidence = np.clip(confidence, 0.3, 0.95)
579
+
580
+ blobs.append({
581
+ 'mask': blob_mask,
582
+ 'bbox': [x_min, y_min, x_max, y_max],
583
+ 'area': blob_area,
584
+ 'confidence': float(confidence)
585
+ })
586
+
587
+ return abnormal_mask, blobs
588
+
589
+ def segment_with_concepts(
590
+ self,
591
+ image: Union[str, Path, Image.Image, np.ndarray],
592
+ text_prompts: List[str],
593
+ confidence_threshold: Optional[float] = None
594
+ ) -> SegmentationResult:
595
+ """
596
+ Segment disease regions based on color analysis.
597
+
598
+ Analyzes the image colors to detect abnormal (non-green) regions
599
+ that may indicate disease. Returns empty results for healthy images.
600
+ """
601
+ # Load image
602
+ if isinstance(image, (str, Path)):
603
+ image = Image.open(image).convert("RGB")
604
+ elif isinstance(image, np.ndarray):
605
+ image = Image.fromarray(image)
606
+
607
+ w, h = image.size
608
+ threshold = confidence_threshold or self.config["inference"]["confidence_threshold"]
609
+
610
+ # Detect disease regions based on color
611
+ abnormal_mask, blobs = self._detect_disease_regions(image)
612
+
613
+ # Filter by confidence threshold
614
+ blobs = [b for b in blobs if b['confidence'] >= threshold]
615
+
616
+ if not blobs:
617
+ logger.info("No disease regions detected (healthy image)")
618
+ return SegmentationResult(
619
+ masks=np.zeros((0, h, w), dtype=bool),
620
+ boxes=np.zeros((0, 4), dtype=np.float32),
621
+ scores=np.zeros((0,), dtype=np.float32),
622
+ prompts=text_prompts,
623
+ prompt_indices=np.zeros((0,), dtype=np.int32)
624
+ )
625
+
626
+ # Convert blobs to arrays
627
+ num_detections = len(blobs)
628
+ masks = np.zeros((num_detections, h, w), dtype=bool)
629
+ boxes = np.zeros((num_detections, 4), dtype=np.float32)
630
+ scores = np.zeros(num_detections, dtype=np.float32)
631
+
632
+ # Assign detections to first disease-related prompt (skip "healthy" prompts)
633
+ disease_prompt_idx = 0
634
+ for idx, prompt in enumerate(text_prompts):
635
+ if "healthy" not in prompt.lower():
636
+ disease_prompt_idx = idx
637
+ break
638
+
639
+ prompt_indices = np.full(num_detections, disease_prompt_idx, dtype=np.int32)
640
+
641
+ for i, blob in enumerate(blobs):
642
+ masks[i] = blob['mask']
643
+ boxes[i] = blob['bbox']
644
+ scores[i] = blob['confidence']
645
+
646
+ logger.info(f"Detected {num_detections} disease region(s)")
647
+
648
+ return SegmentationResult(
649
+ masks=masks,
650
+ boxes=boxes,
651
+ scores=scores,
652
+ prompts=text_prompts,
653
+ prompt_indices=prompt_indices
654
+ )
655
+
656
+
657
+ class RFDETRSegmenter(SAM3Segmenter):
658
+ """
659
+ RF-DETR based object detection for plant disease detection.
660
+
661
+ Uses a trained RF-DETR model (DETR-based detector) instead of SAM 3.
662
+ RF-DETR is trained on annotated plant disease datasets with bounding boxes.
663
+
664
+ Example:
665
+ >>> segmenter = RFDETRSegmenter("models/rfdetr/best.pt")
666
+ >>> result = segmenter.segment_with_concepts(image, ["disease"])
667
+ >>> print(f"Found {len(result.masks)} disease regions")
668
+ """
669
+
670
+ def __init__(
671
+ self,
672
+ checkpoint_path: str = "models/rfdetr/best.pt",
673
+ config_path: str = "configs/sam3_config.yaml",
674
+ device: Optional[str] = None,
675
+ model_size: str = "medium"
676
+ ):
677
+ """
678
+ Initialize RF-DETR segmenter.
679
+
680
+ Args:
681
+ checkpoint_path: Path to trained RF-DETR checkpoint
682
+ config_path: Path to configuration file
683
+ device: Device to use (auto-detected if None)
684
+ model_size: RF-DETR model size (nano, small, medium, base)
685
+ """
686
+ super().__init__(checkpoint_path, config_path, device)
687
+ self.model_size = model_size
688
+ self.class_names = ["Pestalotiopsis"] # Default class, updated after loading
689
+
690
+ def load_model(self):
691
+ """Load RF-DETR model."""
692
+ if self.model is not None:
693
+ return
694
+
695
+ logger.info(f"Loading RF-DETR {self.model_size} model...")
696
+
697
+ try:
698
+ # Import RF-DETR
699
+ if self.model_size == "nano":
700
+ from rfdetr import RFDETRNano as RFDETRModel
701
+ elif self.model_size == "small":
702
+ from rfdetr import RFDETRSmall as RFDETRModel
703
+ elif self.model_size == "medium":
704
+ from rfdetr import RFDETRMedium as RFDETRModel
705
+ else:
706
+ from rfdetr import RFDETRBase as RFDETRModel
707
+
708
+ # Load model with custom weights if available
709
+ checkpoint = Path(self.checkpoint_path)
710
+ if checkpoint.exists():
711
+ logger.info(f"Loading custom weights from {checkpoint}")
712
+ self.model = RFDETRModel(pretrain_weights=str(checkpoint))
713
+ else:
714
+ logger.warning(f"Checkpoint not found at {checkpoint}, using pretrained weights")
715
+ self.model = RFDETRModel()
716
+
717
+ logger.info("RF-DETR model loaded successfully")
718
+
719
+ except ImportError as e:
720
+ logger.error(f"RF-DETR not installed: {e}")
721
+ logger.error("Install with: pip install rfdetr")
722
+ raise
723
+
724
+ def segment_with_concepts(
725
+ self,
726
+ image: Union[str, Path, Image.Image, np.ndarray],
727
+ text_prompts: List[str],
728
+ confidence_threshold: Optional[float] = None
729
+ ) -> SegmentationResult:
730
+ """
731
+ Detect disease regions using RF-DETR.
732
+
733
+ Note: RF-DETR is class-based (not prompt-based), so text_prompts
734
+ are ignored. The model detects all trained disease classes.
735
+
736
+ Args:
737
+ image: Input image
738
+ text_prompts: Ignored (RF-DETR uses trained classes)
739
+ confidence_threshold: Detection confidence threshold
740
+
741
+ Returns:
742
+ SegmentationResult with detected disease regions
743
+ """
744
+ self.load_model()
745
+
746
+ # Load image
747
+ if isinstance(image, (str, Path)):
748
+ pil_image = Image.open(image).convert("RGB")
749
+ elif isinstance(image, np.ndarray):
750
+ pil_image = Image.fromarray(image)
751
+ else:
752
+ pil_image = image
753
+
754
+ w, h = pil_image.size
755
+ threshold = confidence_threshold or self.config["inference"]["confidence_threshold"]
756
+
757
+ # Run RF-DETR detection
758
+ try:
759
+ detections = self.model.predict(pil_image, threshold=threshold)
760
+ except Exception as e:
761
+ logger.error(f"RF-DETR prediction failed: {e}")
762
+ return SegmentationResult(
763
+ masks=np.zeros((0, h, w), dtype=bool),
764
+ boxes=np.zeros((0, 4), dtype=np.float32),
765
+ scores=np.zeros((0,), dtype=np.float32),
766
+ prompts=text_prompts,
767
+ prompt_indices=np.zeros((0,), dtype=np.int32)
768
+ )
769
+
770
+ # Extract detections from supervision Detections object
771
+ num_detections = len(detections)
772
+
773
+ if num_detections == 0:
774
+ logger.info("No disease regions detected")
775
+ return SegmentationResult(
776
+ masks=np.zeros((0, h, w), dtype=bool),
777
+ boxes=np.zeros((0, 4), dtype=np.float32),
778
+ scores=np.zeros((0,), dtype=np.float32),
779
+ prompts=text_prompts,
780
+ prompt_indices=np.zeros((0,), dtype=np.int32)
781
+ )
782
+
783
+ # Get bounding boxes and scores
784
+ boxes = detections.xyxy.astype(np.float32) # [x1, y1, x2, y2]
785
+ scores = detections.confidence.astype(np.float32)
786
+
787
+ # Create masks from bounding boxes (RF-DETR gives boxes, not masks)
788
+ masks = np.zeros((num_detections, h, w), dtype=bool)
789
+ for i, box in enumerate(boxes):
790
+ x1, y1, x2, y2 = box.astype(int)
791
+ x1, y1 = max(0, x1), max(0, y1)
792
+ x2, y2 = min(w, x2), min(h, y2)
793
+ masks[i, y1:y2, x1:x2] = True
794
+
795
+ # Assign to first disease prompt
796
+ disease_prompt_idx = 0
797
+ for idx, prompt in enumerate(text_prompts):
798
+ if "healthy" not in prompt.lower():
799
+ disease_prompt_idx = idx
800
+ break
801
+
802
+ prompt_indices = np.full(num_detections, disease_prompt_idx, dtype=np.int32)
803
+
804
+ logger.info(f"RF-DETR detected {num_detections} disease region(s)")
805
+
806
+ return SegmentationResult(
807
+ masks=masks,
808
+ boxes=boxes,
809
+ scores=scores,
810
+ prompts=text_prompts,
811
+ prompt_indices=prompt_indices
812
+ )
813
+
814
+
815
+ def create_segmenter(
816
+ checkpoint_path: str = "models/sam3/sam3.pt",
817
+ config_path: str = "configs/sam3_config.yaml",
818
+ use_mock: bool = False,
819
+ use_rfdetr: bool = False,
820
+ rfdetr_checkpoint: str = "models/rfdetr/best.pt",
821
+ rfdetr_model_size: str = "medium"
822
+ ) -> SAM3Segmenter:
823
+ """
824
+ Factory function to create appropriate segmenter.
825
+
826
+ Args:
827
+ checkpoint_path: Path to SAM 3 checkpoint
828
+ config_path: Path to configuration
829
+ use_mock: If True, use color-based mock segmenter
830
+ use_rfdetr: If True, use RF-DETR detector
831
+ rfdetr_checkpoint: Path to RF-DETR checkpoint
832
+ rfdetr_model_size: RF-DETR model size (nano, small, medium, base)
833
+
834
+ Returns:
835
+ SAM3Segmenter, MockSAM3Segmenter, or RFDETRSegmenter instance
836
+ """
837
+ if use_rfdetr:
838
+ return RFDETRSegmenter(
839
+ checkpoint_path=rfdetr_checkpoint,
840
+ config_path=config_path,
841
+ model_size=rfdetr_model_size
842
+ )
843
+ elif use_mock:
844
+ return MockSAM3Segmenter(checkpoint_path, config_path)
845
+ else:
846
+ return SAM3Segmenter(checkpoint_path, config_path)
847
+
848
+
849
+ if __name__ == "__main__":
850
+ # Quick test with mock
851
+ segmenter = create_segmenter(use_mock=True)
852
+
853
+ # Create a test image
854
+ test_image = Image.new("RGB", (640, 480), color=(34, 139, 34)) # Forest green
855
+
856
+ prompts = ["diseased leaf", "brown spots", "healthy tissue"]
857
+ result = segmenter.segment_with_concepts(test_image, prompts)
858
+
859
+ print(f"Found {len(result.masks)} regions")
860
+ print(f"Scores: {result.scores}")
861
+ print(f"Prompts used: {[result.prompts[i] for i in result.prompt_indices]}")
862
+
863
+ areas = segmenter.calculate_affected_area(result)
864
+ print(f"Affected area: {areas['total_affected_percent']:.1f}%")
src/severity_classifier.py ADDED
@@ -0,0 +1,590 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Severity Classifier Module for CropDoctor-Semantic
3
+ ===================================================
4
+
5
+ This module provides a CNN-based classifier to assess the severity
6
+ of plant diseases from segmented regions.
7
+
8
+ Severity Levels:
9
+ 0 - Healthy: No disease symptoms
10
+ 1 - Mild: <10% affected area, early stage
11
+ 2 - Moderate: 10-30% affected, established infection
12
+ 3 - Severe: >30% affected, critical intervention needed
13
+ """
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from torch.utils.data import Dataset, DataLoader
19
+ import torchvision.transforms as transforms
20
+ from torchvision import models
21
+ import numpy as np
22
+ from PIL import Image
23
+ from pathlib import Path
24
+ from typing import Tuple, Dict, List, Optional, Union
25
+ from dataclasses import dataclass
26
+ import logging
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ @dataclass
32
+ class SeverityPrediction:
33
+ """Container for severity classification results."""
34
+ severity_level: int # 0-3
35
+ severity_label: str # "healthy", "mild", "moderate", "severe"
36
+ confidence: float # 0-1
37
+ probabilities: Dict[str, float] # Per-class probabilities
38
+ affected_area_percent: float # From mask analysis
39
+
40
+
41
+ # Severity level mapping
42
+ SEVERITY_LABELS = {
43
+ 0: "healthy",
44
+ 1: "mild",
45
+ 2: "moderate",
46
+ 3: "severe"
47
+ }
48
+
49
+ SEVERITY_DESCRIPTIONS = {
50
+ 0: "No disease symptoms detected. Plant appears healthy.",
51
+ 1: "Early stage infection. Less than 10% of tissue affected. Preventive action recommended.",
52
+ 2: "Established infection. 10-30% of tissue affected. Treatment required.",
53
+ 3: "Severe infection. Over 30% of tissue affected. Urgent intervention needed."
54
+ }
55
+
56
+
57
+ class SeverityClassifierCNN(nn.Module):
58
+ """
59
+ CNN model for disease severity classification.
60
+
61
+ Architecture options:
62
+ - EfficientNet-B0 (lightweight, fast)
63
+ - ResNet-50 (balanced)
64
+ - ConvNeXt-Tiny (modern, accurate)
65
+ """
66
+
67
+ def __init__(
68
+ self,
69
+ num_classes: int = 4,
70
+ backbone: str = "efficientnet_b0",
71
+ pretrained: bool = True,
72
+ dropout: float = 0.3
73
+ ):
74
+ super().__init__()
75
+
76
+ self.num_classes = num_classes
77
+ self.backbone_name = backbone
78
+
79
+ # Load backbone
80
+ if backbone == "efficientnet_b0":
81
+ self.backbone = models.efficientnet_b0(
82
+ weights=models.EfficientNet_B0_Weights.DEFAULT if pretrained else None
83
+ )
84
+ in_features = self.backbone.classifier[1].in_features
85
+ self.backbone.classifier = nn.Sequential(
86
+ nn.Dropout(dropout),
87
+ nn.Linear(in_features, num_classes)
88
+ )
89
+
90
+ elif backbone == "resnet50":
91
+ self.backbone = models.resnet50(
92
+ weights=models.ResNet50_Weights.DEFAULT if pretrained else None
93
+ )
94
+ in_features = self.backbone.fc.in_features
95
+ self.backbone.fc = nn.Sequential(
96
+ nn.Dropout(dropout),
97
+ nn.Linear(in_features, num_classes)
98
+ )
99
+
100
+ elif backbone == "convnext_tiny":
101
+ self.backbone = models.convnext_tiny(
102
+ weights=models.ConvNeXt_Tiny_Weights.DEFAULT if pretrained else None
103
+ )
104
+ in_features = self.backbone.classifier[2].in_features
105
+ self.backbone.classifier = nn.Sequential(
106
+ nn.Flatten(1),
107
+ nn.LayerNorm(in_features),
108
+ nn.Dropout(dropout),
109
+ nn.Linear(in_features, num_classes)
110
+ )
111
+ else:
112
+ raise ValueError(f"Unknown backbone: {backbone}")
113
+
114
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
115
+ return self.backbone(x)
116
+
117
+ def predict(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
118
+ """Get predictions with probabilities."""
119
+ logits = self.forward(x)
120
+ probs = F.softmax(logits, dim=1)
121
+ preds = torch.argmax(probs, dim=1)
122
+ return preds, probs
123
+
124
+
125
+ class SeverityClassifier:
126
+ """
127
+ High-level interface for severity classification.
128
+
129
+ Handles image preprocessing, model loading, and prediction formatting.
130
+
131
+ Example:
132
+ >>> classifier = SeverityClassifier("models/severity_classifier/best.pt")
133
+ >>> result = classifier.classify("diseased_leaf.jpg")
134
+ >>> print(f"Severity: {result.severity_label} ({result.confidence:.2f})")
135
+ """
136
+
137
+ def __init__(
138
+ self,
139
+ checkpoint_path: Optional[str] = None,
140
+ device: Optional[str] = None,
141
+ image_size: int = 224
142
+ ):
143
+ """
144
+ Initialize severity classifier.
145
+
146
+ Args:
147
+ checkpoint_path: Path to trained model checkpoint
148
+ device: Device to use (auto-detected if None)
149
+ image_size: Input image size for the model
150
+ """
151
+ # Set device
152
+ if device is None:
153
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
154
+ else:
155
+ self.device = device
156
+
157
+ self.image_size = image_size
158
+ self.checkpoint_path = checkpoint_path
159
+
160
+ # Initialize model
161
+ self.model = None
162
+ self._setup_transforms()
163
+
164
+ def _setup_transforms(self):
165
+ """Setup image preprocessing transforms."""
166
+ # ImageNet normalization
167
+ self.transform = transforms.Compose([
168
+ transforms.Resize((self.image_size, self.image_size)),
169
+ transforms.ToTensor(),
170
+ transforms.Normalize(
171
+ mean=[0.485, 0.456, 0.406],
172
+ std=[0.229, 0.224, 0.225]
173
+ )
174
+ ])
175
+
176
+ # Augmentation for training
177
+ self.train_transform = transforms.Compose([
178
+ transforms.RandomResizedCrop(self.image_size, scale=(0.8, 1.0)),
179
+ transforms.RandomHorizontalFlip(),
180
+ transforms.RandomVerticalFlip(),
181
+ transforms.RandomRotation(15),
182
+ transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
183
+ transforms.ToTensor(),
184
+ transforms.Normalize(
185
+ mean=[0.485, 0.456, 0.406],
186
+ std=[0.229, 0.224, 0.225]
187
+ )
188
+ ])
189
+
190
+ def load_model(self, backbone: str = "efficientnet_b0"):
191
+ """Load or initialize the model."""
192
+ if self.model is not None:
193
+ return
194
+
195
+ self.model = SeverityClassifierCNN(
196
+ num_classes=4,
197
+ backbone=backbone,
198
+ pretrained=True
199
+ )
200
+
201
+ if self.checkpoint_path and Path(self.checkpoint_path).exists():
202
+ logger.info(f"Loading checkpoint from {self.checkpoint_path}")
203
+ checkpoint = torch.load(self.checkpoint_path, map_location=self.device, weights_only=False)
204
+ self.model.load_state_dict(checkpoint["model_state_dict"])
205
+ else:
206
+ logger.warning("No checkpoint loaded, using pretrained backbone only")
207
+
208
+ self.model.to(self.device)
209
+ self.model.eval()
210
+
211
+ def preprocess_image(
212
+ self,
213
+ image: Union[str, Path, Image.Image, np.ndarray]
214
+ ) -> torch.Tensor:
215
+ """Preprocess image for classification."""
216
+ if isinstance(image, (str, Path)):
217
+ image = Image.open(image).convert("RGB")
218
+ elif isinstance(image, np.ndarray):
219
+ image = Image.fromarray(image)
220
+
221
+ tensor = self.transform(image)
222
+ return tensor.unsqueeze(0) # Add batch dimension
223
+
224
+ def classify(
225
+ self,
226
+ image: Union[str, Path, Image.Image, np.ndarray],
227
+ mask: Optional[np.ndarray] = None
228
+ ) -> SeverityPrediction:
229
+ """
230
+ Classify disease severity in an image.
231
+
232
+ Args:
233
+ image: Input image (path, PIL Image, or numpy array)
234
+ mask: Optional binary mask of diseased region
235
+
236
+ Returns:
237
+ SeverityPrediction with severity level, confidence, and details
238
+ """
239
+ self.load_model()
240
+
241
+ # Calculate affected area from mask
242
+ affected_percent = 0.0
243
+ if mask is not None:
244
+ affected_percent = (mask.sum() / mask.size) * 100
245
+
246
+ # Preprocess and predict
247
+ input_tensor = self.preprocess_image(image).to(self.device)
248
+
249
+ with torch.no_grad():
250
+ pred, probs = self.model.predict(input_tensor)
251
+
252
+ severity_level = pred.item()
253
+ confidence = probs[0, severity_level].item()
254
+
255
+ # Format probabilities
256
+ prob_dict = {
257
+ SEVERITY_LABELS[i]: probs[0, i].item()
258
+ for i in range(4)
259
+ }
260
+
261
+ return SeverityPrediction(
262
+ severity_level=severity_level,
263
+ severity_label=SEVERITY_LABELS[severity_level],
264
+ confidence=confidence,
265
+ probabilities=prob_dict,
266
+ affected_area_percent=affected_percent
267
+ )
268
+
269
+ def classify_region(
270
+ self,
271
+ image: Union[str, Path, Image.Image, np.ndarray],
272
+ mask: np.ndarray
273
+ ) -> SeverityPrediction:
274
+ """
275
+ Classify severity of a specific masked region.
276
+
277
+ Extracts the bounding box of the mask and classifies that region.
278
+
279
+ Args:
280
+ image: Full image
281
+ mask: Binary mask of region to classify
282
+
283
+ Returns:
284
+ SeverityPrediction for the masked region
285
+ """
286
+ # Load image if needed
287
+ if isinstance(image, (str, Path)):
288
+ image = Image.open(image).convert("RGB")
289
+ elif isinstance(image, np.ndarray):
290
+ image = Image.fromarray(image)
291
+
292
+ img_array = np.array(image)
293
+
294
+ # Get bounding box from mask
295
+ rows = np.any(mask, axis=1)
296
+ cols = np.any(mask, axis=0)
297
+
298
+ if not rows.any() or not cols.any():
299
+ # Empty mask, return healthy
300
+ return SeverityPrediction(
301
+ severity_level=0,
302
+ severity_label="healthy",
303
+ confidence=1.0,
304
+ probabilities={"healthy": 1.0, "mild": 0.0, "moderate": 0.0, "severe": 0.0},
305
+ affected_area_percent=0.0
306
+ )
307
+
308
+ y_min, y_max = np.where(rows)[0][[0, -1]]
309
+ x_min, x_max = np.where(cols)[0][[0, -1]]
310
+
311
+ # Add padding
312
+ pad = 10
313
+ y_min = max(0, y_min - pad)
314
+ y_max = min(img_array.shape[0], y_max + pad)
315
+ x_min = max(0, x_min - pad)
316
+ x_max = min(img_array.shape[1], x_max + pad)
317
+
318
+ # Crop region
319
+ cropped = img_array[y_min:y_max, x_min:x_max]
320
+ cropped_mask = mask[y_min:y_max, x_min:x_max]
321
+
322
+ return self.classify(cropped, mask=cropped_mask)
323
+
324
+ def classify_batch(
325
+ self,
326
+ images: List[Union[str, Path, Image.Image, np.ndarray]],
327
+ masks: Optional[List[np.ndarray]] = None,
328
+ batch_size: int = 16
329
+ ) -> List[SeverityPrediction]:
330
+ """
331
+ Classify multiple images in batches.
332
+
333
+ Args:
334
+ images: List of images to classify
335
+ masks: Optional list of masks for each image
336
+ batch_size: Batch size for inference
337
+
338
+ Returns:
339
+ List of SeverityPrediction for each image
340
+ """
341
+ self.load_model()
342
+
343
+ results = []
344
+
345
+ for i in range(0, len(images), batch_size):
346
+ batch_images = images[i:i + batch_size]
347
+ batch_masks = masks[i:i + batch_size] if masks else [None] * len(batch_images)
348
+
349
+ # Preprocess batch
350
+ tensors = [self.preprocess_image(img) for img in batch_images]
351
+ batch_tensor = torch.cat(tensors, dim=0).to(self.device)
352
+
353
+ # Predict
354
+ with torch.no_grad():
355
+ preds, probs = self.model.predict(batch_tensor)
356
+
357
+ # Format results
358
+ for j, (pred, prob) in enumerate(zip(preds, probs)):
359
+ mask = batch_masks[j]
360
+ affected_percent = 0.0
361
+ if mask is not None:
362
+ affected_percent = (mask.sum() / mask.size) * 100
363
+
364
+ severity_level = pred.item()
365
+
366
+ results.append(SeverityPrediction(
367
+ severity_level=severity_level,
368
+ severity_label=SEVERITY_LABELS[severity_level],
369
+ confidence=prob[severity_level].item(),
370
+ probabilities={
371
+ SEVERITY_LABELS[k]: prob[k].item()
372
+ for k in range(4)
373
+ },
374
+ affected_area_percent=affected_percent
375
+ ))
376
+
377
+ return results
378
+
379
+
380
+ class PlantDiseaseDataset(Dataset):
381
+ """
382
+ Dataset class for training severity classifier.
383
+
384
+ Expected folder structure:
385
+ data_root/
386
+ healthy/
387
+ image1.jpg
388
+ image2.jpg
389
+ mild/
390
+ ...
391
+ moderate/
392
+ ...
393
+ severe/
394
+ ...
395
+ """
396
+
397
+ def __init__(
398
+ self,
399
+ data_root: str,
400
+ transform: Optional[transforms.Compose] = None,
401
+ split: str = "train"
402
+ ):
403
+ self.data_root = Path(data_root)
404
+ self.transform = transform
405
+ self.split = split
406
+
407
+ # Collect image paths and labels
408
+ self.samples = []
409
+
410
+ for label_idx, label_name in SEVERITY_LABELS.items():
411
+ label_dir = self.data_root / label_name
412
+ if label_dir.exists():
413
+ for img_path in label_dir.glob("*.jpg"):
414
+ self.samples.append((img_path, label_idx))
415
+ for img_path in label_dir.glob("*.png"):
416
+ self.samples.append((img_path, label_idx))
417
+
418
+ logger.info(f"Loaded {len(self.samples)} samples for {split}")
419
+
420
+ def __len__(self) -> int:
421
+ return len(self.samples)
422
+
423
+ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
424
+ img_path, label = self.samples[idx]
425
+
426
+ image = Image.open(img_path).convert("RGB")
427
+
428
+ if self.transform:
429
+ image = self.transform(image)
430
+
431
+ return image, label
432
+
433
+
434
+ def train_classifier(
435
+ train_data_root: str,
436
+ val_data_root: str,
437
+ output_dir: str,
438
+ backbone: str = "efficientnet_b0",
439
+ epochs: int = 50,
440
+ batch_size: int = 32,
441
+ learning_rate: float = 1e-4,
442
+ device: str = "cuda"
443
+ ):
444
+ """
445
+ Train the severity classifier.
446
+
447
+ Args:
448
+ train_data_root: Path to training data
449
+ val_data_root: Path to validation data
450
+ output_dir: Where to save checkpoints
451
+ backbone: Model backbone to use
452
+ epochs: Number of training epochs
453
+ batch_size: Training batch size
454
+ learning_rate: Initial learning rate
455
+ device: Device to train on
456
+ """
457
+ output_dir = Path(output_dir)
458
+ output_dir.mkdir(parents=True, exist_ok=True)
459
+
460
+ # Setup classifier for transforms
461
+ classifier = SeverityClassifier()
462
+
463
+ # Create datasets
464
+ train_dataset = PlantDiseaseDataset(
465
+ train_data_root,
466
+ transform=classifier.train_transform,
467
+ split="train"
468
+ )
469
+ val_dataset = PlantDiseaseDataset(
470
+ val_data_root,
471
+ transform=classifier.transform,
472
+ split="val"
473
+ )
474
+
475
+ # Create dataloaders
476
+ train_loader = DataLoader(
477
+ train_dataset,
478
+ batch_size=batch_size,
479
+ shuffle=True,
480
+ num_workers=4,
481
+ pin_memory=True
482
+ )
483
+ val_loader = DataLoader(
484
+ val_dataset,
485
+ batch_size=batch_size,
486
+ shuffle=False,
487
+ num_workers=4
488
+ )
489
+
490
+ # Initialize model
491
+ model = SeverityClassifierCNN(
492
+ num_classes=4,
493
+ backbone=backbone,
494
+ pretrained=True
495
+ ).to(device)
496
+
497
+ # Loss and optimizer
498
+ criterion = nn.CrossEntropyLoss()
499
+ optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
500
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
501
+
502
+ best_val_acc = 0.0
503
+
504
+ for epoch in range(epochs):
505
+ # Training
506
+ model.train()
507
+ train_loss = 0.0
508
+ train_correct = 0
509
+ train_total = 0
510
+
511
+ for images, labels in train_loader:
512
+ images, labels = images.to(device), labels.to(device)
513
+
514
+ optimizer.zero_grad()
515
+ outputs = model(images)
516
+ loss = criterion(outputs, labels)
517
+ loss.backward()
518
+ optimizer.step()
519
+
520
+ train_loss += loss.item()
521
+ _, predicted = outputs.max(1)
522
+ train_total += labels.size(0)
523
+ train_correct += predicted.eq(labels).sum().item()
524
+
525
+ # Validation
526
+ model.eval()
527
+ val_loss = 0.0
528
+ val_correct = 0
529
+ val_total = 0
530
+
531
+ with torch.no_grad():
532
+ for images, labels in val_loader:
533
+ images, labels = images.to(device), labels.to(device)
534
+ outputs = model(images)
535
+ loss = criterion(outputs, labels)
536
+
537
+ val_loss += loss.item()
538
+ _, predicted = outputs.max(1)
539
+ val_total += labels.size(0)
540
+ val_correct += predicted.eq(labels).sum().item()
541
+
542
+ train_acc = 100. * train_correct / train_total
543
+ val_acc = 100. * val_correct / val_total
544
+
545
+ scheduler.step()
546
+
547
+ logger.info(
548
+ f"Epoch {epoch+1}/{epochs} - "
549
+ f"Train Loss: {train_loss/len(train_loader):.4f}, Train Acc: {train_acc:.2f}% - "
550
+ f"Val Loss: {val_loss/len(val_loader):.4f}, Val Acc: {val_acc:.2f}%"
551
+ )
552
+
553
+ # Save best model
554
+ if val_acc > best_val_acc:
555
+ best_val_acc = val_acc
556
+ torch.save({
557
+ "epoch": epoch,
558
+ "model_state_dict": model.state_dict(),
559
+ "optimizer_state_dict": optimizer.state_dict(),
560
+ "val_acc": val_acc,
561
+ "backbone": backbone
562
+ }, output_dir / "best.pt")
563
+ logger.info(f"Saved best model with val_acc: {val_acc:.2f}%")
564
+
565
+ # Save final model
566
+ torch.save({
567
+ "epoch": epochs,
568
+ "model_state_dict": model.state_dict(),
569
+ "optimizer_state_dict": optimizer.state_dict(),
570
+ "val_acc": val_acc,
571
+ "backbone": backbone
572
+ }, output_dir / "final.pt")
573
+
574
+ logger.info(f"Training complete. Best val_acc: {best_val_acc:.2f}%")
575
+
576
+
577
+ if __name__ == "__main__":
578
+ # Quick test
579
+ classifier = SeverityClassifier()
580
+
581
+ # Create a test image
582
+ test_image = Image.new("RGB", (224, 224), color=(139, 69, 19)) # Brown
583
+
584
+ # Test classification (will use random weights without checkpoint)
585
+ result = classifier.classify(test_image)
586
+
587
+ print(f"Severity: {result.severity_label}")
588
+ print(f"Confidence: {result.confidence:.2f}")
589
+ print(f"Probabilities: {result.probabilities}")
590
+ print(f"Description: {SEVERITY_DESCRIPTIONS[result.severity_level]}")
src/treatment_recommender.py ADDED
@@ -0,0 +1,525 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Treatment Recommender Module for CropDoctor-Semantic
3
+ =====================================================
4
+
5
+ This module uses Claude API to generate contextual treatment
6
+ recommendations based on disease diagnosis results.
7
+
8
+ Features:
9
+ - Disease identification from symptoms
10
+ - Treatment recommendations (organic/chemical)
11
+ - Preventive measures
12
+ - Timing and application guidance
13
+ """
14
+
15
+ import os
16
+ from typing import Dict, List, Optional, Union
17
+ from dataclasses import dataclass
18
+ import json
19
+ import logging
20
+ from pathlib import Path
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ @dataclass
26
+ class TreatmentRecommendation:
27
+ """Container for treatment recommendations."""
28
+ disease_name: str
29
+ disease_type: str # fungal, bacterial, viral, pest, nutrient
30
+ confidence: float
31
+ symptoms_matched: List[str]
32
+ organic_treatments: List[str]
33
+ chemical_treatments: List[str]
34
+ preventive_measures: List[str]
35
+ timing: str
36
+ urgency: str # low, medium, high, critical
37
+ additional_notes: str
38
+
39
+
40
+ # Disease knowledge base for offline fallback
41
+ DISEASE_DATABASE = {
42
+ "powdery_mildew": {
43
+ "name": "Powdery Mildew",
44
+ "type": "fungal",
45
+ "symptoms": ["white powdery coating", "curled leaves", "stunted growth"],
46
+ "organic": [
47
+ "Neem oil spray (1 tbsp per liter water)",
48
+ "Baking soda solution (1 tsp per liter + few drops soap)",
49
+ "Milk spray (40% milk, 60% water)",
50
+ "Sulfur-based fungicide"
51
+ ],
52
+ "chemical": [
53
+ "Myclobutanil",
54
+ "Triadimefon",
55
+ "Propiconazole"
56
+ ],
57
+ "prevention": [
58
+ "Improve air circulation",
59
+ "Avoid overhead watering",
60
+ "Remove infected plant parts",
61
+ "Plant resistant varieties"
62
+ ],
63
+ "timing": "Apply at first sign of infection, repeat every 7-14 days"
64
+ },
65
+ "leaf_spot": {
66
+ "name": "Leaf Spot Disease",
67
+ "type": "fungal",
68
+ "symptoms": ["brown spots", "circular lesions", "yellow halos"],
69
+ "organic": [
70
+ "Copper-based fungicide",
71
+ "Neem oil treatment",
72
+ "Remove and destroy infected leaves"
73
+ ],
74
+ "chemical": [
75
+ "Chlorothalonil",
76
+ "Mancozeb",
77
+ "Azoxystrobin"
78
+ ],
79
+ "prevention": [
80
+ "Water at soil level",
81
+ "Mulch around plants",
82
+ "Rotate crops annually",
83
+ "Maintain proper spacing"
84
+ ],
85
+ "timing": "Begin treatment early, apply every 7-10 days during wet weather"
86
+ },
87
+ "bacterial_blight": {
88
+ "name": "Bacterial Blight",
89
+ "type": "bacterial",
90
+ "symptoms": ["water-soaked lesions", "angular spots", "wilting"],
91
+ "organic": [
92
+ "Copper hydroxide spray",
93
+ "Remove infected plants",
94
+ "Improve drainage"
95
+ ],
96
+ "chemical": [
97
+ "Streptomycin sulfate",
98
+ "Oxytetracycline",
99
+ "Copper-based bactericides"
100
+ ],
101
+ "prevention": [
102
+ "Use disease-free seeds",
103
+ "Avoid working with wet plants",
104
+ "Sanitize tools between plants",
105
+ "Practice crop rotation"
106
+ ],
107
+ "timing": "Apply preventively or at first symptoms, difficult to cure once established"
108
+ },
109
+ "viral_mosaic": {
110
+ "name": "Viral Mosaic Disease",
111
+ "type": "viral",
112
+ "symptoms": ["mosaic pattern", "mottled leaves", "distorted growth"],
113
+ "organic": [
114
+ "Remove and destroy infected plants",
115
+ "Control aphid vectors",
116
+ "Use reflective mulch"
117
+ ],
118
+ "chemical": [
119
+ "No direct treatment available",
120
+ "Control vectors with insecticides"
121
+ ],
122
+ "prevention": [
123
+ "Plant resistant varieties",
124
+ "Control aphid populations",
125
+ "Remove weeds that harbor virus",
126
+ "Sanitize tools frequently"
127
+ ],
128
+ "timing": "Prevention is key - no cure available once infected"
129
+ },
130
+ "nutrient_deficiency": {
131
+ "name": "Nutrient Deficiency",
132
+ "type": "nutrient",
133
+ "symptoms": ["chlorosis", "yellowing", "purple coloration"],
134
+ "organic": [
135
+ "Compost application",
136
+ "Foliar feeding with seaweed extract",
137
+ "Fish emulsion fertilizer"
138
+ ],
139
+ "chemical": [
140
+ "Balanced NPK fertilizer",
141
+ "Iron chelate for iron deficiency",
142
+ "Epsom salt for magnesium deficiency"
143
+ ],
144
+ "prevention": [
145
+ "Regular soil testing",
146
+ "Maintain soil pH 6.0-7.0",
147
+ "Add organic matter annually",
148
+ "Mulch to retain nutrients"
149
+ ],
150
+ "timing": "Apply fertilizers during active growth, avoid over-fertilization"
151
+ },
152
+ "aphid_infestation": {
153
+ "name": "Aphid Infestation",
154
+ "type": "pest",
155
+ "symptoms": ["sticky residue", "curled leaves", "visible insects"],
156
+ "organic": [
157
+ "Strong water spray to dislodge",
158
+ "Neem oil spray",
159
+ "Insecticidal soap",
160
+ "Release ladybugs or lacewings"
161
+ ],
162
+ "chemical": [
163
+ "Pyrethrin-based insecticides",
164
+ "Imidacloprid (systemic)",
165
+ "Malathion"
166
+ ],
167
+ "prevention": [
168
+ "Encourage beneficial insects",
169
+ "Remove weeds",
170
+ "Avoid excess nitrogen fertilization",
171
+ "Use reflective mulch"
172
+ ],
173
+ "timing": "Treat early when populations are low, monitor regularly"
174
+ }
175
+ }
176
+
177
+
178
+ class TreatmentRecommender:
179
+ """
180
+ Generate treatment recommendations using LLM.
181
+
182
+ Uses Claude API for intelligent, contextual recommendations
183
+ with fallback to local knowledge base.
184
+
185
+ Example:
186
+ >>> recommender = TreatmentRecommender(api_key="your_key")
187
+ >>> result = recommender.get_recommendation(
188
+ ... symptoms=["brown spots", "yellowing"],
189
+ ... severity="moderate",
190
+ ... plant_species="tomato"
191
+ ... )
192
+ >>> print(result.disease_name)
193
+ >>> print(result.organic_treatments)
194
+ """
195
+
196
+ def __init__(
197
+ self,
198
+ api_key: Optional[str] = None,
199
+ use_llm: bool = True,
200
+ model: str = "claude-sonnet-4-20250514"
201
+ ):
202
+ """
203
+ Initialize treatment recommender.
204
+
205
+ Args:
206
+ api_key: Anthropic API key (uses env var ANTHROPIC_API_KEY if None)
207
+ use_llm: Whether to use LLM for recommendations
208
+ model: Claude model to use
209
+ """
210
+ self.use_llm = use_llm
211
+ self.model = model
212
+ self.client = None
213
+
214
+ if use_llm:
215
+ api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
216
+ if api_key:
217
+ try:
218
+ import anthropic
219
+ self.client = anthropic.Anthropic(api_key=api_key)
220
+ logger.info("Anthropic client initialized successfully")
221
+ except ImportError:
222
+ logger.warning("anthropic package not installed, using offline mode")
223
+ self.use_llm = False
224
+ else:
225
+ logger.warning("No API key provided, using offline knowledge base")
226
+ self.use_llm = False
227
+
228
+ def _build_prompt(
229
+ self,
230
+ symptoms: List[str],
231
+ severity: str,
232
+ plant_species: Optional[str] = None,
233
+ affected_area_percent: Optional[float] = None,
234
+ additional_context: Optional[str] = None
235
+ ) -> str:
236
+ """Build the prompt for Claude."""
237
+
238
+ prompt = f"""You are an expert plant pathologist providing diagnosis and treatment recommendations.
239
+
240
+ Based on the following observations, identify the most likely disease and provide treatment recommendations.
241
+
242
+ ## Observations:
243
+ - **Symptoms detected**: {', '.join(symptoms)}
244
+ - **Severity level**: {severity}
245
+ """
246
+
247
+ if plant_species:
248
+ prompt += f"- **Plant species**: {plant_species}\n"
249
+
250
+ if affected_area_percent is not None:
251
+ prompt += f"- **Affected area**: {affected_area_percent:.1f}%\n"
252
+
253
+ if additional_context:
254
+ prompt += f"- **Additional context**: {additional_context}\n"
255
+
256
+ prompt += """
257
+ ## Please provide:
258
+ 1. **Disease identification**: Name and type (fungal/bacterial/viral/pest/nutrient)
259
+ 2. **Confidence level**: How confident are you in this diagnosis (0-100%)
260
+ 3. **Organic treatments**: List 3-5 organic/natural treatment options
261
+ 4. **Chemical treatments**: List 2-3 chemical treatment options (if needed)
262
+ 5. **Preventive measures**: List 3-5 prevention strategies
263
+ 6. **Timing**: When and how often to apply treatments
264
+ 7. **Urgency**: How urgent is treatment (low/medium/high/critical)
265
+ 8. **Additional notes**: Any important considerations
266
+
267
+ Format your response as JSON with these exact keys:
268
+ {
269
+ "disease_name": "string",
270
+ "disease_type": "string",
271
+ "confidence": number,
272
+ "symptoms_matched": ["string"],
273
+ "organic_treatments": ["string"],
274
+ "chemical_treatments": ["string"],
275
+ "preventive_measures": ["string"],
276
+ "timing": "string",
277
+ "urgency": "string",
278
+ "additional_notes": "string"
279
+ }
280
+ """
281
+ return prompt
282
+
283
+ def get_recommendation(
284
+ self,
285
+ symptoms: List[str],
286
+ severity: str = "moderate",
287
+ plant_species: Optional[str] = None,
288
+ affected_area_percent: Optional[float] = None,
289
+ additional_context: Optional[str] = None
290
+ ) -> TreatmentRecommendation:
291
+ """
292
+ Get treatment recommendation for detected symptoms.
293
+
294
+ Args:
295
+ symptoms: List of detected symptoms
296
+ severity: Severity level (healthy, mild, moderate, severe)
297
+ plant_species: Optional plant species for more specific advice
298
+ affected_area_percent: Percentage of plant affected
299
+ additional_context: Any additional observations
300
+
301
+ Returns:
302
+ TreatmentRecommendation with diagnosis and treatment options
303
+ """
304
+ if self.use_llm and self.client:
305
+ return self._get_llm_recommendation(
306
+ symptoms, severity, plant_species,
307
+ affected_area_percent, additional_context
308
+ )
309
+ else:
310
+ return self._get_offline_recommendation(symptoms, severity)
311
+
312
+ def _get_llm_recommendation(
313
+ self,
314
+ symptoms: List[str],
315
+ severity: str,
316
+ plant_species: Optional[str],
317
+ affected_area_percent: Optional[float],
318
+ additional_context: Optional[str]
319
+ ) -> TreatmentRecommendation:
320
+ """Get recommendation from Claude API."""
321
+
322
+ prompt = self._build_prompt(
323
+ symptoms, severity, plant_species,
324
+ affected_area_percent, additional_context
325
+ )
326
+
327
+ try:
328
+ response = self.client.messages.create(
329
+ model=self.model,
330
+ max_tokens=1500,
331
+ messages=[
332
+ {"role": "user", "content": prompt}
333
+ ]
334
+ )
335
+
336
+ # Parse JSON response
337
+ content = response.content[0].text
338
+
339
+ # Extract JSON from response
340
+ start = content.find('{')
341
+ end = content.rfind('}') + 1
342
+ if start >= 0 and end > start:
343
+ json_str = content[start:end]
344
+ data = json.loads(json_str)
345
+
346
+ return TreatmentRecommendation(
347
+ disease_name=data.get("disease_name", "Unknown"),
348
+ disease_type=data.get("disease_type", "unknown"),
349
+ confidence=data.get("confidence", 0) / 100,
350
+ symptoms_matched=data.get("symptoms_matched", symptoms),
351
+ organic_treatments=data.get("organic_treatments", []),
352
+ chemical_treatments=data.get("chemical_treatments", []),
353
+ preventive_measures=data.get("preventive_measures", []),
354
+ timing=data.get("timing", "Consult local extension service"),
355
+ urgency=data.get("urgency", "medium"),
356
+ additional_notes=data.get("additional_notes", "")
357
+ )
358
+ else:
359
+ logger.warning("Could not parse LLM response, using offline mode")
360
+ return self._get_offline_recommendation(symptoms, severity)
361
+
362
+ except Exception as e:
363
+ logger.error(f"LLM request failed: {e}")
364
+ return self._get_offline_recommendation(symptoms, severity)
365
+
366
+ def _get_offline_recommendation(
367
+ self,
368
+ symptoms: List[str],
369
+ severity: str
370
+ ) -> TreatmentRecommendation:
371
+ """Get recommendation from local knowledge base."""
372
+
373
+ # Simple symptom matching
374
+ best_match = None
375
+ best_score = 0
376
+ matched_symptoms = []
377
+
378
+ symptoms_lower = [s.lower() for s in symptoms]
379
+
380
+ for disease_key, disease_data in DISEASE_DATABASE.items():
381
+ score = 0
382
+ matches = []
383
+
384
+ for symptom in disease_data["symptoms"]:
385
+ for input_symptom in symptoms_lower:
386
+ if any(word in input_symptom for word in symptom.split()):
387
+ score += 1
388
+ matches.append(symptom)
389
+ break
390
+
391
+ if score > best_score:
392
+ best_score = score
393
+ best_match = disease_key
394
+ matched_symptoms = matches
395
+
396
+ if best_match is None:
397
+ # Default to general recommendation
398
+ return TreatmentRecommendation(
399
+ disease_name="Unidentified Condition",
400
+ disease_type="unknown",
401
+ confidence=0.3,
402
+ symptoms_matched=symptoms,
403
+ organic_treatments=[
404
+ "Apply broad-spectrum organic fungicide",
405
+ "Improve plant nutrition",
406
+ "Remove affected plant parts"
407
+ ],
408
+ chemical_treatments=[
409
+ "Consult local agricultural extension",
410
+ "Get professional diagnosis"
411
+ ],
412
+ preventive_measures=[
413
+ "Improve air circulation",
414
+ "Avoid overwatering",
415
+ "Maintain plant hygiene"
416
+ ],
417
+ timing="Monitor closely and treat at first sign of worsening",
418
+ urgency="medium" if severity in ["mild", "moderate"] else "high",
419
+ additional_notes="Unable to identify specific disease. Consider professional diagnosis."
420
+ )
421
+
422
+ disease = DISEASE_DATABASE[best_match]
423
+
424
+ # Calculate urgency based on severity
425
+ urgency_map = {
426
+ "healthy": "low",
427
+ "mild": "medium",
428
+ "moderate": "high",
429
+ "severe": "critical"
430
+ }
431
+
432
+ return TreatmentRecommendation(
433
+ disease_name=disease["name"],
434
+ disease_type=disease["type"],
435
+ confidence=min(0.9, 0.5 + (best_score * 0.15)),
436
+ symptoms_matched=matched_symptoms,
437
+ organic_treatments=disease["organic"],
438
+ chemical_treatments=disease["chemical"],
439
+ preventive_measures=disease["prevention"],
440
+ timing=disease["timing"],
441
+ urgency=urgency_map.get(severity, "medium"),
442
+ additional_notes=f"Diagnosis based on {best_score} symptom matches."
443
+ )
444
+
445
+ def format_report(
446
+ self,
447
+ recommendation: TreatmentRecommendation,
448
+ include_chemical: bool = True
449
+ ) -> str:
450
+ """
451
+ Format recommendation as a readable report.
452
+
453
+ Args:
454
+ recommendation: TreatmentRecommendation to format
455
+ include_chemical: Whether to include chemical treatments
456
+
457
+ Returns:
458
+ Formatted string report
459
+ """
460
+ urgency_emoji = {
461
+ "low": "🟒",
462
+ "medium": "🟑",
463
+ "high": "🟠",
464
+ "critical": "πŸ”΄"
465
+ }
466
+
467
+ report = f"""
468
+ ╔══════════════════════════════════════════════════════════════╗
469
+ β•‘ DIAGNOSTIC REPORT β•‘
470
+ β•šβ•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•
471
+
472
+ πŸ“‹ DIAGNOSIS
473
+ Disease: {recommendation.disease_name}
474
+ Type: {recommendation.disease_type.capitalize()}
475
+ Confidence: {recommendation.confidence:.0%}
476
+ Urgency: {urgency_emoji.get(recommendation.urgency, 'βšͺ')} {recommendation.urgency.upper()}
477
+
478
+ πŸ“ SYMPTOMS MATCHED
479
+ """
480
+ for symptom in recommendation.symptoms_matched:
481
+ report += f" β€’ {symptom}\n"
482
+
483
+ report += """
484
+ 🌿 ORGANIC TREATMENTS
485
+ """
486
+ for treatment in recommendation.organic_treatments:
487
+ report += f" β€’ {treatment}\n"
488
+
489
+ if include_chemical and recommendation.chemical_treatments:
490
+ report += """
491
+ πŸ§ͺ CHEMICAL TREATMENTS
492
+ """
493
+ for treatment in recommendation.chemical_treatments:
494
+ report += f" β€’ {treatment}\n"
495
+
496
+ report += """
497
+ πŸ›‘οΈ PREVENTIVE MEASURES
498
+ """
499
+ for measure in recommendation.preventive_measures:
500
+ report += f" β€’ {measure}\n"
501
+
502
+ report += f"""
503
+ ⏰ TIMING
504
+ {recommendation.timing}
505
+
506
+ πŸ“ NOTES
507
+ {recommendation.additional_notes}
508
+
509
+ ═══════════════════════════════════════════════════════════════
510
+ """
511
+ return report
512
+
513
+
514
+ if __name__ == "__main__":
515
+ # Test without API
516
+ recommender = TreatmentRecommender(use_llm=False)
517
+
518
+ # Test with symptoms
519
+ result = recommender.get_recommendation(
520
+ symptoms=["brown spots", "yellow halos", "circular lesions"],
521
+ severity="moderate",
522
+ plant_species="tomato"
523
+ )
524
+
525
+ print(recommender.format_report(result))
src/visualization.py ADDED
@@ -0,0 +1,460 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Visualization Module for CropDoctor-Semantic
3
+ =============================================
4
+
5
+ This module provides visualization functions for:
6
+ - Segmentation masks overlay
7
+ - Severity heatmaps
8
+ - Diagnostic dashboards
9
+ - Comparison views
10
+ """
11
+
12
+ import numpy as np
13
+ from PIL import Image
14
+ import matplotlib.pyplot as plt
15
+ import matplotlib.patches as mpatches
16
+ from matplotlib.colors import LinearSegmentedColormap
17
+ from typing import List, Optional, Tuple, Union, Dict
18
+ from pathlib import Path
19
+ import logging
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+ # Color schemes
24
+ SEVERITY_COLORS = {
25
+ 'healthy': '#2ECC71', # Green
26
+ 'mild': '#F1C40F', # Yellow
27
+ 'moderate': '#E67E22', # Orange
28
+ 'severe': '#E74C3C' # Red
29
+ }
30
+
31
+ SYMPTOM_COLORS = [
32
+ '#E74C3C', # Red
33
+ '#9B59B6', # Purple
34
+ '#3498DB', # Blue
35
+ '#E67E22', # Orange
36
+ '#1ABC9C', # Teal
37
+ '#F39C12', # Yellow
38
+ '#D35400', # Dark Orange
39
+ '#8E44AD', # Dark Purple
40
+ ]
41
+
42
+
43
+ def create_diagnostic_visualization(
44
+ image: Union[str, Path, Image.Image, np.ndarray],
45
+ masks: Optional[np.ndarray] = None,
46
+ severity_label: str = "unknown",
47
+ disease_name: str = "Unknown",
48
+ affected_percent: float = 0.0,
49
+ prompt_labels: Optional[List[str]] = None,
50
+ figsize: Tuple[int, int] = (16, 6)
51
+ ) -> plt.Figure:
52
+ """
53
+ Create a comprehensive diagnostic visualization.
54
+
55
+ Args:
56
+ image: Input image
57
+ masks: Segmentation masks array (N, H, W)
58
+ severity_label: Severity classification result
59
+ disease_name: Identified disease name
60
+ affected_percent: Percentage of affected area
61
+ prompt_labels: Labels for each mask
62
+ figsize: Figure size
63
+
64
+ Returns:
65
+ matplotlib Figure object
66
+ """
67
+ # Load image
68
+ if isinstance(image, (str, Path)):
69
+ image = Image.open(image).convert("RGB")
70
+ elif isinstance(image, np.ndarray):
71
+ image = Image.fromarray(image)
72
+
73
+ img_array = np.array(image)
74
+
75
+ # Create figure with subplots
76
+ fig, axes = plt.subplots(1, 3, figsize=figsize)
77
+ fig.suptitle(f'CropDoctor Diagnostic Report', fontsize=14, fontweight='bold')
78
+
79
+ # Panel 1: Original Image
80
+ axes[0].imshow(img_array)
81
+ axes[0].set_title('Original Image', fontsize=12)
82
+ axes[0].axis('off')
83
+
84
+ # Panel 2: Segmentation Overlay
85
+ if masks is not None and len(masks) > 0:
86
+ overlay = create_mask_overlay(img_array, masks, alpha=0.5)
87
+ axes[1].imshow(overlay)
88
+
89
+ # Create legend
90
+ if prompt_labels:
91
+ patches = []
92
+ for i, label in enumerate(prompt_labels[:len(SYMPTOM_COLORS)]):
93
+ color = SYMPTOM_COLORS[i % len(SYMPTOM_COLORS)]
94
+ patches.append(mpatches.Patch(color=color, label=label, alpha=0.7))
95
+ axes[1].legend(handles=patches, loc='upper right', fontsize=8)
96
+ else:
97
+ axes[1].imshow(img_array)
98
+ axes[1].text(0.5, 0.5, 'No disease regions detected',
99
+ transform=axes[1].transAxes, ha='center', va='center',
100
+ fontsize=12, color='green')
101
+
102
+ axes[1].set_title('Disease Detection', fontsize=12)
103
+ axes[1].axis('off')
104
+
105
+ # Panel 3: Diagnostic Summary
106
+ axes[2].axis('off')
107
+
108
+ # Create summary text
109
+ severity_color = SEVERITY_COLORS.get(severity_label.lower(), '#95A5A6')
110
+
111
+ summary_text = f"""
112
+ ╔════════════════════════════════╗
113
+ β•‘ DIAGNOSTIC SUMMARY β•‘
114
+ β•šβ•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•
115
+
116
+ πŸ“‹ Disease: {disease_name}
117
+
118
+ ⚠️ Severity: {severity_label.upper()}
119
+
120
+ πŸ“Š Affected Area: {affected_percent:.1f}%
121
+
122
+ """
123
+
124
+ # Add severity indicator
125
+ axes[2].text(0.5, 0.65, summary_text, transform=axes[2].transAxes,
126
+ fontsize=11, fontfamily='monospace',
127
+ verticalalignment='top', horizontalalignment='center')
128
+
129
+ # Add severity color bar
130
+ severity_bar = plt.Rectangle((0.15, 0.25), 0.7, 0.1,
131
+ facecolor=severity_color,
132
+ edgecolor='black',
133
+ transform=axes[2].transAxes)
134
+ axes[2].add_patch(severity_bar)
135
+ axes[2].text(0.5, 0.30, severity_label.upper(),
136
+ transform=axes[2].transAxes, ha='center', va='center',
137
+ fontsize=12, fontweight='bold', color='white')
138
+
139
+ # Add affected area progress bar
140
+ bar_width = 0.7 * (affected_percent / 100)
141
+ bg_bar = plt.Rectangle((0.15, 0.12), 0.7, 0.06,
142
+ facecolor='#EEEEEE', edgecolor='black',
143
+ transform=axes[2].transAxes)
144
+ progress_bar = plt.Rectangle((0.15, 0.12), max(0.01, bar_width), 0.06,
145
+ facecolor='#E74C3C',
146
+ transform=axes[2].transAxes)
147
+ axes[2].add_patch(bg_bar)
148
+ axes[2].add_patch(progress_bar)
149
+ axes[2].text(0.5, 0.08, f'Affected Area: {affected_percent:.1f}%',
150
+ transform=axes[2].transAxes, ha='center', fontsize=10)
151
+
152
+ plt.tight_layout()
153
+
154
+ return fig
155
+
156
+
157
+ def create_mask_overlay(
158
+ image: np.ndarray,
159
+ masks: np.ndarray,
160
+ alpha: float = 0.5,
161
+ colors: Optional[List[str]] = None
162
+ ) -> np.ndarray:
163
+ """
164
+ Create an overlay of segmentation masks on an image.
165
+
166
+ Args:
167
+ image: RGB image array (H, W, 3)
168
+ masks: Binary masks (N, H, W)
169
+ alpha: Transparency of overlay
170
+ colors: Optional list of colors for masks
171
+
172
+ Returns:
173
+ Image array with mask overlay
174
+ """
175
+ if colors is None:
176
+ colors = SYMPTOM_COLORS
177
+
178
+ # Start with the original image
179
+ overlay = image.copy().astype(np.float32)
180
+
181
+ for i, mask in enumerate(masks):
182
+ if mask.any():
183
+ # Get color for this mask
184
+ color_hex = colors[i % len(colors)]
185
+ color_rgb = hex_to_rgb(color_hex)
186
+
187
+ # Create colored mask
188
+ colored_mask = np.zeros_like(overlay)
189
+ colored_mask[mask] = color_rgb
190
+
191
+ # Blend with overlay
192
+ mask_3d = np.stack([mask] * 3, axis=-1)
193
+ overlay = np.where(
194
+ mask_3d,
195
+ overlay * (1 - alpha) + colored_mask * alpha,
196
+ overlay
197
+ )
198
+
199
+ return overlay.astype(np.uint8)
200
+
201
+
202
+ def create_severity_heatmap(
203
+ image: Union[str, Path, Image.Image, np.ndarray],
204
+ severity_map: np.ndarray,
205
+ figsize: Tuple[int, int] = (12, 5)
206
+ ) -> plt.Figure:
207
+ """
208
+ Create a heatmap showing severity distribution across the image.
209
+
210
+ Args:
211
+ image: Input image
212
+ severity_map: Array of severity values (H, W) with values 0-3
213
+ figsize: Figure size
214
+
215
+ Returns:
216
+ matplotlib Figure object
217
+ """
218
+ # Load image
219
+ if isinstance(image, (str, Path)):
220
+ image = Image.open(image).convert("RGB")
221
+ elif isinstance(image, np.ndarray):
222
+ image = Image.fromarray(image)
223
+
224
+ img_array = np.array(image)
225
+
226
+ # Create custom colormap
227
+ colors = ['#2ECC71', '#F1C40F', '#E67E22', '#E74C3C'] # Green to Red
228
+ cmap = LinearSegmentedColormap.from_list('severity', colors, N=256)
229
+
230
+ fig, axes = plt.subplots(1, 2, figsize=figsize)
231
+
232
+ # Original image
233
+ axes[0].imshow(img_array)
234
+ axes[0].set_title('Original Image')
235
+ axes[0].axis('off')
236
+
237
+ # Heatmap overlay
238
+ axes[1].imshow(img_array)
239
+ heatmap = axes[1].imshow(severity_map, cmap=cmap, alpha=0.6, vmin=0, vmax=3)
240
+ axes[1].set_title('Severity Heatmap')
241
+ axes[1].axis('off')
242
+
243
+ # Add colorbar
244
+ cbar = plt.colorbar(heatmap, ax=axes[1], fraction=0.046, pad=0.04)
245
+ cbar.set_ticks([0, 1, 2, 3])
246
+ cbar.set_ticklabels(['Healthy', 'Mild', 'Moderate', 'Severe'])
247
+
248
+ plt.tight_layout()
249
+
250
+ return fig
251
+
252
+
253
+ def create_comparison_view(
254
+ images: List[Union[str, Path, Image.Image]],
255
+ results: List[Dict],
256
+ cols: int = 4,
257
+ figsize_per_image: Tuple[float, float] = (4, 5)
258
+ ) -> plt.Figure:
259
+ """
260
+ Create a grid comparison view of multiple diagnoses.
261
+
262
+ Args:
263
+ images: List of images
264
+ results: List of diagnostic results (dicts with 'severity_label', 'disease_name', etc.)
265
+ cols: Number of columns in grid
266
+ figsize_per_image: Size per image in the grid
267
+
268
+ Returns:
269
+ matplotlib Figure object
270
+ """
271
+ n_images = len(images)
272
+ rows = (n_images + cols - 1) // cols
273
+
274
+ fig, axes = plt.subplots(
275
+ rows, cols,
276
+ figsize=(figsize_per_image[0] * cols, figsize_per_image[1] * rows)
277
+ )
278
+
279
+ if rows == 1:
280
+ axes = [axes]
281
+ if cols == 1:
282
+ axes = [[ax] for ax in axes]
283
+
284
+ for i, (img, result) in enumerate(zip(images, results)):
285
+ row = i // cols
286
+ col = i % cols
287
+ ax = axes[row][col] if rows > 1 else axes[col]
288
+
289
+ # Load image
290
+ if isinstance(img, (str, Path)):
291
+ img = Image.open(img).convert("RGB")
292
+
293
+ ax.imshow(img)
294
+ ax.axis('off')
295
+
296
+ # Add colored border based on severity
297
+ severity = result.get('severity_label', 'unknown')
298
+ color = SEVERITY_COLORS.get(severity.lower(), '#95A5A6')
299
+
300
+ for spine in ax.spines.values():
301
+ spine.set_edgecolor(color)
302
+ spine.set_linewidth(4)
303
+ spine.set_visible(True)
304
+
305
+ # Add label
306
+ ax.set_title(
307
+ f"{result.get('disease_name', 'Unknown')}\n{severity.upper()}",
308
+ fontsize=10,
309
+ color=color
310
+ )
311
+
312
+ # Hide empty subplots
313
+ for i in range(n_images, rows * cols):
314
+ row = i // cols
315
+ col = i % cols
316
+ ax = axes[row][col] if rows > 1 else axes[col]
317
+ ax.axis('off')
318
+ ax.set_visible(False)
319
+
320
+ plt.tight_layout()
321
+
322
+ return fig
323
+
324
+
325
+ def create_treatment_card(
326
+ result: Dict,
327
+ figsize: Tuple[int, int] = (8, 10)
328
+ ) -> plt.Figure:
329
+ """
330
+ Create a treatment recommendation card.
331
+
332
+ Args:
333
+ result: Diagnostic result dictionary
334
+ figsize: Figure size
335
+
336
+ Returns:
337
+ matplotlib Figure object
338
+ """
339
+ fig, ax = plt.subplots(figsize=figsize)
340
+ ax.axis('off')
341
+
342
+ severity_color = SEVERITY_COLORS.get(
343
+ result.get('severity_label', 'unknown').lower(),
344
+ '#95A5A6'
345
+ )
346
+
347
+ # Title
348
+ ax.text(0.5, 0.95, '🌿 TREATMENT CARD',
349
+ ha='center', va='top', fontsize=16, fontweight='bold',
350
+ transform=ax.transAxes)
351
+
352
+ # Disease info
353
+ disease_text = f"""
354
+ ╔═══════════════════════════════════════════╗
355
+ β•‘ Disease: {result.get('disease_name', 'Unknown'):<32}β•‘
356
+ β•‘ Type: {result.get('disease_type', 'unknown'):<35}β•‘
357
+ β•‘ Severity: {result.get('severity_label', 'unknown').upper():<31}β•‘
358
+ β•‘ Affected Area: {result.get('affected_area_percent', 0):.1f}%{' ' * 25}β•‘
359
+ β•šβ•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•
360
+ """
361
+ ax.text(0.5, 0.85, disease_text,
362
+ ha='center', va='top', fontfamily='monospace', fontsize=10,
363
+ transform=ax.transAxes)
364
+
365
+ # Treatments
366
+ y_pos = 0.60
367
+
368
+ # Organic treatments
369
+ ax.text(0.1, y_pos, '🌱 ORGANIC TREATMENTS', fontweight='bold', fontsize=11,
370
+ transform=ax.transAxes)
371
+ y_pos -= 0.03
372
+
373
+ for treatment in result.get('organic_treatments', [])[:4]:
374
+ ax.text(0.12, y_pos, f'β€’ {treatment[:50]}', fontsize=9,
375
+ transform=ax.transAxes)
376
+ y_pos -= 0.03
377
+
378
+ y_pos -= 0.02
379
+
380
+ # Chemical treatments
381
+ if result.get('chemical_treatments'):
382
+ ax.text(0.1, y_pos, 'πŸ§ͺ CHEMICAL TREATMENTS', fontweight='bold', fontsize=11,
383
+ transform=ax.transAxes)
384
+ y_pos -= 0.03
385
+
386
+ for treatment in result.get('chemical_treatments', [])[:3]:
387
+ ax.text(0.12, y_pos, f'β€’ {treatment[:50]}', fontsize=9,
388
+ transform=ax.transAxes)
389
+ y_pos -= 0.03
390
+
391
+ y_pos -= 0.02
392
+
393
+ # Prevention
394
+ ax.text(0.1, y_pos, 'πŸ›‘οΈ PREVENTION', fontweight='bold', fontsize=11,
395
+ transform=ax.transAxes)
396
+ y_pos -= 0.03
397
+
398
+ for measure in result.get('preventive_measures', [])[:4]:
399
+ ax.text(0.12, y_pos, f'β€’ {measure[:50]}', fontsize=9,
400
+ transform=ax.transAxes)
401
+ y_pos -= 0.03
402
+
403
+ # Timing
404
+ y_pos -= 0.02
405
+ ax.text(0.1, y_pos, f"⏰ TIMING: {result.get('treatment_timing', 'Consult expert')[:60]}",
406
+ fontsize=9, transform=ax.transAxes)
407
+
408
+ # Add border
409
+ rect = plt.Rectangle((0.05, 0.05), 0.9, 0.92,
410
+ fill=False, edgecolor=severity_color, linewidth=3,
411
+ transform=ax.transAxes)
412
+ ax.add_patch(rect)
413
+
414
+ return fig
415
+
416
+
417
+ def hex_to_rgb(hex_color: str) -> List[int]:
418
+ """Convert hex color to RGB."""
419
+ hex_color = hex_color.lstrip('#')
420
+ return [int(hex_color[i:i+2], 16) for i in (0, 2, 4)]
421
+
422
+
423
+ def save_visualization(
424
+ fig: plt.Figure,
425
+ output_path: Union[str, Path],
426
+ dpi: int = 150
427
+ ):
428
+ """Save figure to file."""
429
+ output_path = Path(output_path)
430
+ output_path.parent.mkdir(parents=True, exist_ok=True)
431
+ fig.savefig(output_path, dpi=dpi, bbox_inches='tight', facecolor='white')
432
+ plt.close(fig)
433
+ logger.info(f"Visualization saved to {output_path}")
434
+
435
+
436
+ if __name__ == "__main__":
437
+ # Test visualizations
438
+ import numpy as np
439
+
440
+ # Create test image
441
+ test_img = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
442
+ test_img[:, :, 1] = 139 # Greenish tint
443
+
444
+ # Create test masks
445
+ test_masks = np.zeros((2, 480, 640), dtype=bool)
446
+ test_masks[0, 100:200, 100:200] = True # Square mask
447
+ test_masks[1, 300:400, 400:500] = True # Another square
448
+
449
+ # Test diagnostic visualization
450
+ fig = create_diagnostic_visualization(
451
+ test_img,
452
+ test_masks,
453
+ severity_label="moderate",
454
+ disease_name="Leaf Spot Disease",
455
+ affected_percent=15.5,
456
+ prompt_labels=["brown spots", "yellowing"]
457
+ )
458
+
459
+ save_visualization(fig, "/tmp/test_diagnostic.png")
460
+ print("Test visualization saved to /tmp/test_diagnostic.png")