# cropscan-space / app.py — Hugging Face Space
# Last commit: davidsv, "Add 5 example images" (586fa01)
#!/usr/bin/env python3
"""
CropScan - Plant Disease Detection
Hugging Face Space Demo
"""
import gradio as gr
import numpy as np
from PIL import Image
from pathlib import Path
import torch
# Checkpoint locations, relative to the Space's working directory.
MODEL_DIR = Path("models")
# RF-DETR detector weights used for disease-region detection.
RFDETR_CHECKPOINT = MODEL_DIR / "rfdetr" / "checkpoint_best_total.pth"
# SAM2 weights used for leaf segmentation and box-to-mask refinement.
SAM2_CHECKPOINT = MODEL_DIR / "sam2" / "sam2.1_hiera_small.pt"
# Lazily instantiated model wrappers; populated by load_models() on first use.
segmenter = None       # RFDETRSegmenter once loaded
leaf_segmenter = None  # SAM2LeafSegmenter once loaded
# Custom stylesheet injected into the Gradio app: dark GitHub-style palette
# (#0d1117 background, #161b22 panels) with green accents (#238636/#3fb950),
# status banners for healthy/warning/danger results, and a hidden default
# Gradio footer. No animations or visual effects by design.
CUSTOM_CSS = """
.gradio-container {
background: #0d1117 !important;
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
}
.main-header {
text-align: center;
padding: 1.5rem 0;
border-bottom: 1px solid #21262d;
margin-bottom: 1.5rem;
}
.main-header h1 {
font-size: 2rem;
font-weight: 600;
color: #3fb950;
margin: 0;
}
.main-header p {
color: #8b949e;
margin-top: 0.25rem;
font-size: 0.9rem;
}
.section-label {
color: #c9d1d9;
font-size: 0.875rem;
font-weight: 500;
margin-bottom: 0.75rem;
}
.result-box {
background: #161b22;
border: 1px solid #30363d;
border-radius: 8px;
padding: 1rem;
color: #c9d1d9;
}
.healthy-status {
background: #0d1117;
border: 1px solid #238636;
border-radius: 6px;
padding: 1rem;
margin-bottom: 1rem;
}
.healthy-status h4 {
color: #3fb950;
margin: 0 0 0.25rem 0;
font-size: 1rem;
}
.warning-status {
background: #0d1117;
border: 1px solid #9e6a03;
border-radius: 6px;
padding: 1rem;
margin-bottom: 1rem;
}
.warning-status h4 {
color: #d29922;
margin: 0 0 0.25rem 0;
font-size: 1rem;
}
.danger-status {
background: #0d1117;
border: 1px solid #da3633;
border-radius: 6px;
padding: 1rem;
margin-bottom: 1rem;
}
.danger-status h4 {
color: #f85149;
margin: 0 0 0.25rem 0;
font-size: 1rem;
}
button.primary {
background: #238636 !important;
border: none !important;
}
button.primary:hover {
background: #2ea043 !important;
}
input[type="range"] {
accent-color: #238636 !important;
}
input[type="checkbox"] {
accent-color: #238636 !important;
}
.footer-text {
text-align: center;
color: #484f58;
font-size: 0.8rem;
padding: 1rem;
border-top: 1px solid #21262d;
margin-top: 1rem;
}
.examples-title {
color: #3fb950 !important;
font-size: 1.1rem !important;
font-weight: 600 !important;
margin-bottom: 1rem !important;
}
.gallery img {
height: 120px !important;
width: auto !important;
object-fit: cover !important;
}
footer {
display: none !important;
}
"""
def load_models():
    """Instantiate the detector and leaf segmenter once, caching them globally.

    Subsequent calls are no-ops once ``segmenter`` is populated.
    """
    global segmenter, leaf_segmenter
    if segmenter is None:
        print("Loading models...")
        # Deferred imports: the heavy model code is only pulled in the first
        # time a scan is actually requested.
        from src.leaf_segmenter import SAM2LeafSegmenter
        from src.sam3_segmentation import RFDETRSegmenter
        segmenter = RFDETRSegmenter(checkpoint_path=str(RFDETR_CHECKPOINT), model_size="medium")
        leaf_segmenter = SAM2LeafSegmenter(checkpoint_path=str(SAM2_CHECKPOINT))
        print("Models loaded!")
def get_care_recommendations(num_detections: int, affected_percent: float) -> str:
    """Build the HTML/markdown care panel shown next to the scan result.

    Args:
        num_detections: Number of detected disease regions; 0 means healthy.
        affected_percent: Percentage of leaf/image area covered by detections.

    Returns:
        A markdown string containing a status banner div plus care advice.
    """
    if not num_detections:
        # No detections: green "Healthy" banner with generic upkeep tips.
        return """<div class="healthy-status">
<h4>Healthy</h4>
<p style="color: #8b949e; margin: 0;">No disease symptoms detected.</p>
</div>
**Care tips:**
- Continue regular watering
- Ensure adequate sunlight
- Monitor for changes
"""
    # Map the affected-area percentage onto a severity label and banner class:
    # <10% Low, <30% Moderate (both amber), otherwise High (red).
    if affected_percent < 10:
        severity, status_class = "Low", "warning-status"
    elif affected_percent < 30:
        severity, status_class = "Moderate", "warning-status"
    else:
        severity, status_class = "High", "danger-status"
    return f"""<div class="{status_class}">
<h4>Disease Detected - {severity}</h4>
<p style="color: #8b949e; margin: 0;">{affected_percent:.1f}% affected | {num_detections} region(s)</p>
</div>
**Recommended actions:**
1. **Isolate** - Separate from healthy plants
2. **Remove affected leaves** - Prune with sterilized tools
3. **Treatment**
- Copper-based fungicide
- Neem oil spray
- Improve air circulation
4. **Monitor** - Check daily for 1-2 weeks
"""
def detect_disease(
    image: np.ndarray,
    use_leaf_segmentation: bool = True,
    confidence_threshold: float = 0.3
) -> tuple:
    """Detect plant diseases in an image.

    Args:
        image: RGB image as a numpy array (from the Gradio Image component),
            or None when nothing was uploaded.
        use_leaf_segmentation: If True, isolate the leaf with SAM2 before
            running detection, and express the affected area relative to the
            leaf mask rather than the whole frame.
        confidence_threshold: Minimum detection confidence forwarded to the
            RF-DETR segmenter.

    Returns:
        (annotated_image, recommendations_markdown); (None, prompt message)
        when no image was provided.
    """
    if image is None:
        return None, "Upload an image to start."
    load_models()
    pil_image = Image.fromarray(image)
    original_image = image.copy()
    segmented_image = None
    leaf_mask = None
    if use_leaf_segmentation:
        # Isolate the leaf first so detection sees leaf pixels only.
        # NOTE(review): assumes auto_segment_leaf returns (PIL image, boolean
        # mask) aligned with the input's spatial dims — confirm in
        # src.leaf_segmenter.
        segmented_pil, leaf_mask = leaf_segmenter.auto_segment_leaf(
            pil_image, return_mask=True
        )
        segmented_image = np.array(segmented_pil)
        detection_input = segmented_pil
    else:
        detection_input = pil_image
    # Text concepts describing what the detector should look for.
    prompts = ["diseased plant tissue", "leaf spot", "disease symptom"]
    seg_result = segmenter.segment_with_concepts(
        detection_input,
        prompts,
        confidence_threshold=confidence_threshold
    )
    num_detections = len(seg_result.boxes)
    if num_detections > 0:
        # Tighten the detector's bounding boxes into pixel masks via SAM2.
        refined_masks = leaf_segmenter.refine_boxes_to_masks(
            detection_input,
            seg_result.boxes
        )
    else:
        # Empty (0, H, W) mask stack so downstream code has a uniform shape.
        refined_masks = np.zeros((0, image.shape[0], image.shape[1]), dtype=bool)
    from src.visualization import create_mask_overlay
    # Draw on the segmented leaf when available, otherwise the raw upload.
    if use_leaf_segmentation and segmented_image is not None:
        base_image = segmented_image
    else:
        base_image = original_image
    if num_detections > 0:
        annotated = create_mask_overlay(base_image, refined_masks, alpha=0.5)
    else:
        annotated = base_image
    affected_percent = 0
    if num_detections > 0:
        # Union of all per-detection masks.
        total_mask = np.zeros(refined_masks[0].shape, dtype=bool)
        for mask in refined_masks:
            total_mask |= mask
        if use_leaf_segmentation and leaf_mask is not None:
            # Affected fraction of the *leaf* area; max(..., 1) guards against
            # division by zero when the leaf mask is empty.
            affected_percent = (total_mask & leaf_mask).sum() / max(leaf_mask.sum(), 1) * 100
        else:
            # Fall back to the affected fraction of the whole frame.
            affected_percent = total_mask.sum() / (image.shape[0] * image.shape[1]) * 100
    recommendations = get_care_recommendations(num_detections, affected_percent)
    return annotated, recommendations
def create_demo():
    """Create the Gradio interface.

    Returns:
        A ``gr.Blocks`` app: image upload + sensitivity controls on the left,
        annotated result and care recommendations on the right, with five
        clickable example images below.
    """
    with gr.Blocks(
        title="CropScan",
        css=CUSTOM_CSS,
        # Base theme overridden to match the dark palette in CUSTOM_CSS.
        theme=gr.themes.Base(
            primary_hue="green",
            neutral_hue="gray",
        ).set(
            body_background_fill="#0d1117",
            block_background_fill="#161b22",
            block_border_width="1px",
            block_border_color="#30363d",
            input_background_fill="#0d1117",
            button_primary_background_fill="#238636",
        )
    ) as demo:
        gr.HTML("""
        <div class="main-header">
        <h1>CropScan</h1>
        <p>Plant disease detection</p>
        </div>
        """)
        with gr.Row():
            # Left column: input image and tuning controls.
            with gr.Column():
                input_image = gr.Image(
                    label="Upload",
                    type="numpy",
                    height=350,
                    sources=["upload", "webcam"]
                )
                # Mapped to detect_disease's confidence_threshold.
                confidence_slider = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=0.3,
                    step=0.05,
                    label="Sensitivity"
                )
                # Toggles SAM2 leaf isolation (detect_disease's
                # use_leaf_segmentation); off by default for speed.
                leaf_seg_checkbox = gr.Checkbox(
                    value=False,
                    label="SAM2 precision mode"
                )
                detect_btn = gr.Button("Scan", variant="primary", size="lg")
            # Right column: annotated result and care recommendations.
            with gr.Column():
                output_image = gr.Image(
                    label="Result",
                    type="numpy",
                    height=350
                )
                detection_info = gr.Markdown()
        gr.HTML('<p class="examples-title">Examples - click to load</p>')
        # NOTE(review): bare filenames — presumably resolved relative to the
        # Space root; confirm the five images exist at these paths.
        example_images = [
            ["img1.jpg"],
            ["img2.jpg"],
            ["img3.jpg"],
            ["img4.jpg"],
            ["img5.jpg"],
        ]
        gr.Examples(
            examples=example_images,
            inputs=[input_image],
            examples_per_page=5,
        )
        # Wire the Scan button to the detection pipeline.
        detect_btn.click(
            fn=detect_disease,
            inputs=[input_image, leaf_seg_checkbox, confidence_slider],
            outputs=[output_image, detection_info]
        )
        gr.HTML('<div class="footer-text">RF-DETR + SAM2</div>')
    return demo
if __name__ == "__main__":
    # Build the interface and start the Gradio server.
    create_demo().launch()