Spaces:
Sleeping
Sleeping
davidsv
commited on
Commit
·
219d945
1
Parent(s):
f8eb07d
Add example images and update SAM2 settings
Browse files- app.py +24 -5
- examples/img1.png +0 -0
- examples/img2.png +0 -0
- examples/img3.png +0 -0
- examples/img4.png +0 -0
- examples/img5.png +0 -0
- models/severity_classifier/best.pt +3 -0
- src/leaf_segmenter.py +23 -1
app.py
CHANGED
|
@@ -166,11 +166,6 @@ Upload a plant leaf image to detect disease regions using AI.
|
|
| 166 |
)
|
| 167 |
|
| 168 |
with gr.Row():
|
| 169 |
-
leaf_seg_checkbox = gr.Checkbox(
|
| 170 |
-
value=True,
|
| 171 |
-
label="Isolate leaf (SAM2)",
|
| 172 |
-
info="Segment leaf before detection to reduce false positives"
|
| 173 |
-
)
|
| 174 |
confidence_slider = gr.Slider(
|
| 175 |
minimum=0.1,
|
| 176 |
maximum=0.9,
|
|
@@ -179,6 +174,12 @@ Upload a plant leaf image to detect disease regions using AI.
|
|
| 179 |
label="Confidence Threshold"
|
| 180 |
)
|
| 181 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
detect_btn = gr.Button("Detect Disease", variant="primary", size="lg")
|
| 183 |
|
| 184 |
with gr.Column():
|
|
@@ -192,6 +193,24 @@ Upload a plant leaf image to detect disease regions using AI.
|
|
| 192 |
value="Upload an image and click 'Detect Disease' to see results."
|
| 193 |
)
|
| 194 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
# Event handler
|
| 196 |
detect_btn.click(
|
| 197 |
fn=detect_disease,
|
|
|
|
| 166 |
)
|
| 167 |
|
| 168 |
with gr.Row():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
confidence_slider = gr.Slider(
|
| 170 |
minimum=0.1,
|
| 171 |
maximum=0.9,
|
|
|
|
| 174 |
label="Confidence Threshold"
|
| 175 |
)
|
| 176 |
|
| 177 |
+
gr.Markdown("**SAM2 Segmentation**: Enable to improve detection accuracy by isolating the leaf and generating precise disease masks instead of bounding boxes.")
|
| 178 |
+
leaf_seg_checkbox = gr.Checkbox(
|
| 179 |
+
value=False,
|
| 180 |
+
label="Enable SAM2 segmentation"
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
detect_btn = gr.Button("Detect Disease", variant="primary", size="lg")
|
| 184 |
|
| 185 |
with gr.Column():
|
|
|
|
| 193 |
value="Upload an image and click 'Detect Disease' to see results."
|
| 194 |
)
|
| 195 |
|
| 196 |
+
# Example images
|
| 197 |
+
gr.Markdown("### Example Images")
|
| 198 |
+
gr.Markdown("Click an example image to load it:")
|
| 199 |
+
|
| 200 |
+
example_images = [
|
| 201 |
+
["examples/img1.png"],
|
| 202 |
+
["examples/img2.png"],
|
| 203 |
+
["examples/img3.png"],
|
| 204 |
+
["examples/img4.png"],
|
| 205 |
+
["examples/img5.png"],
|
| 206 |
+
]
|
| 207 |
+
|
| 208 |
+
gr.Examples(
|
| 209 |
+
examples=example_images,
|
| 210 |
+
inputs=[input_image],
|
| 211 |
+
label="Select an example"
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
# Event handler
|
| 215 |
detect_btn.click(
|
| 216 |
fn=detect_disease,
|
examples/img1.png
ADDED
|
examples/img2.png
ADDED
|
examples/img3.png
ADDED
|
examples/img4.png
ADDED
|
examples/img5.png
ADDED
|
models/severity_classifier/best.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8b8a09bf95bddab9421204e9ca9c2b085e5b941e9532ddccef526801b52eeeaa
|
| 3 |
+
size 48579306
|
src/leaf_segmenter.py
CHANGED
|
@@ -7,7 +7,7 @@ from backgrounds before disease detection.
|
|
| 7 |
|
| 8 |
import numpy as np
|
| 9 |
from PIL import Image
|
| 10 |
-
from typing import Optional, Tuple
|
| 11 |
import torch
|
| 12 |
|
| 13 |
|
|
@@ -50,6 +50,7 @@ class SAM2LeafSegmenter:
|
|
| 50 |
|
| 51 |
self.model = None
|
| 52 |
self.predictor = None
|
|
|
|
| 53 |
|
| 54 |
def load_model(self):
|
| 55 |
"""Load SAM2 model."""
|
|
@@ -68,6 +69,27 @@ class SAM2LeafSegmenter:
|
|
| 68 |
self.predictor = SAM2ImagePredictor(self.model)
|
| 69 |
print("SAM2 model loaded.")
|
| 70 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
def segment_leaf(
|
| 72 |
self,
|
| 73 |
image: Image.Image,
|
|
|
|
| 7 |
|
| 8 |
import numpy as np
|
| 9 |
from PIL import Image
|
| 10 |
+
from typing import Optional, Tuple, List
|
| 11 |
import torch
|
| 12 |
|
| 13 |
|
|
|
|
| 50 |
|
| 51 |
self.model = None
|
| 52 |
self.predictor = None
|
| 53 |
+
self.mask_generator = None
|
| 54 |
|
| 55 |
def load_model(self):
|
| 56 |
"""Load SAM2 model."""
|
|
|
|
| 69 |
self.predictor = SAM2ImagePredictor(self.model)
|
| 70 |
print("SAM2 model loaded.")
|
| 71 |
|
| 72 |
+
def load_mask_generator(self):
|
| 73 |
+
"""Load SAM2 automatic mask generator for multi-object segmentation."""
|
| 74 |
+
self.load_model()
|
| 75 |
+
|
| 76 |
+
if self.mask_generator is not None:
|
| 77 |
+
return
|
| 78 |
+
|
| 79 |
+
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
|
| 80 |
+
|
| 81 |
+
print("Initializing SAM2 automatic mask generator...")
|
| 82 |
+
self.mask_generator = SAM2AutomaticMaskGenerator(
|
| 83 |
+
model=self.model,
|
| 84 |
+
points_per_side=32,
|
| 85 |
+
points_per_batch=64,
|
| 86 |
+
pred_iou_thresh=0.7,
|
| 87 |
+
stability_score_thresh=0.92,
|
| 88 |
+
crop_n_layers=1,
|
| 89 |
+
min_mask_region_area=500,
|
| 90 |
+
)
|
| 91 |
+
print("SAM2 mask generator ready.")
|
| 92 |
+
|
| 93 |
def segment_leaf(
|
| 94 |
self,
|
| 95 |
image: Image.Image,
|