davidsv commited on
Commit
219d945
·
1 Parent(s): f8eb07d

Add example images and update SAM2 settings

Browse files
app.py CHANGED
@@ -166,11 +166,6 @@ Upload a plant leaf image to detect disease regions using AI.
166
  )
167
 
168
  with gr.Row():
169
- leaf_seg_checkbox = gr.Checkbox(
170
- value=True,
171
- label="Isolate leaf (SAM2)",
172
- info="Segment leaf before detection to reduce false positives"
173
- )
174
  confidence_slider = gr.Slider(
175
  minimum=0.1,
176
  maximum=0.9,
@@ -179,6 +174,12 @@ Upload a plant leaf image to detect disease regions using AI.
179
  label="Confidence Threshold"
180
  )
181
 
 
 
 
 
 
 
182
  detect_btn = gr.Button("Detect Disease", variant="primary", size="lg")
183
 
184
  with gr.Column():
@@ -192,6 +193,24 @@ Upload a plant leaf image to detect disease regions using AI.
192
  value="Upload an image and click 'Detect Disease' to see results."
193
  )
194
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  # Event handler
196
  detect_btn.click(
197
  fn=detect_disease,
 
166
  )
167
 
168
  with gr.Row():
 
 
 
 
 
169
  confidence_slider = gr.Slider(
170
  minimum=0.1,
171
  maximum=0.9,
 
174
  label="Confidence Threshold"
175
  )
176
 
177
+ gr.Markdown("**SAM2 Segmentation**: Enable to improve detection accuracy by isolating the leaf and generating precise disease masks instead of bounding boxes.")
178
+ leaf_seg_checkbox = gr.Checkbox(
179
+ value=False,
180
+ label="Enable SAM2 segmentation"
181
+ )
182
+
183
  detect_btn = gr.Button("Detect Disease", variant="primary", size="lg")
184
 
185
  with gr.Column():
 
193
  value="Upload an image and click 'Detect Disease' to see results."
194
  )
195
 
196
+ # Example images
197
+ gr.Markdown("### Example Images")
198
+ gr.Markdown("Click an example image to load it:")
199
+
200
+ example_images = [
201
+ ["examples/img1.png"],
202
+ ["examples/img2.png"],
203
+ ["examples/img3.png"],
204
+ ["examples/img4.png"],
205
+ ["examples/img5.png"],
206
+ ]
207
+
208
+ gr.Examples(
209
+ examples=example_images,
210
+ inputs=[input_image],
211
+ label="Select an example"
212
+ )
213
+
214
  # Event handler
215
  detect_btn.click(
216
  fn=detect_disease,
examples/img1.png ADDED
examples/img2.png ADDED
examples/img3.png ADDED
examples/img4.png ADDED
examples/img5.png ADDED
models/severity_classifier/best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8b8a09bf95bddab9421204e9ca9c2b085e5b941e9532ddccef526801b52eeeaa
3
+ size 48579306
src/leaf_segmenter.py CHANGED
@@ -7,7 +7,7 @@ from backgrounds before disease detection.
7
 
8
  import numpy as np
9
  from PIL import Image
10
- from typing import Optional, Tuple
11
  import torch
12
 
13
 
@@ -50,6 +50,7 @@ class SAM2LeafSegmenter:
50
 
51
  self.model = None
52
  self.predictor = None
 
53
 
54
  def load_model(self):
55
  """Load SAM2 model."""
@@ -68,6 +69,27 @@ class SAM2LeafSegmenter:
68
  self.predictor = SAM2ImagePredictor(self.model)
69
  print("SAM2 model loaded.")
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  def segment_leaf(
72
  self,
73
  image: Image.Image,
 
7
 
8
  import numpy as np
9
  from PIL import Image
10
+ from typing import Optional, Tuple, List
11
  import torch
12
 
13
 
 
50
 
51
  self.model = None
52
  self.predictor = None
53
+ self.mask_generator = None
54
 
55
  def load_model(self):
56
  """Load SAM2 model."""
 
69
  self.predictor = SAM2ImagePredictor(self.model)
70
  print("SAM2 model loaded.")
71
 
72
+ def load_mask_generator(self):
73
+ """Load SAM2 automatic mask generator for multi-object segmentation."""
74
+ self.load_model()
75
+
76
+ if self.mask_generator is not None:
77
+ return
78
+
79
+ from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
80
+
81
+ print("Initializing SAM2 automatic mask generator...")
82
+ self.mask_generator = SAM2AutomaticMaskGenerator(
83
+ model=self.model,
84
+ points_per_side=32,
85
+ points_per_batch=64,
86
+ pred_iou_thresh=0.7,
87
+ stability_score_thresh=0.92,
88
+ crop_n_layers=1,
89
+ min_mask_region_area=500,
90
+ )
91
+ print("SAM2 mask generator ready.")
92
+
93
  def segment_leaf(
94
  self,
95
  image: Image.Image,