MaryPazRB committed on
Commit
8006379
·
1 Parent(s): ee102b7

update app: SAM2 vit_b + SAM3 integration

Browse files
Files changed (2) hide show
  1. README.md +13 -3
  2. app.py +165 -113
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
- title: CLR Severity
3
- emoji: 😻
4
  colorFrom: blue
5
  colorTo: green
6
  sdk: gradio
@@ -8,7 +8,17 @@ sdk_version: 6.9.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
- short_description: Coffee Leaf Rust leaf and severity segmentation framework.
12
  ---
13
 
 
 
 
 
 
 
 
 
 
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: CLR Severity Estimator
3
+ emoji:
4
  colorFrom: blue
5
  colorTo: green
6
  sdk: gradio
 
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
+ short_description: Coffee Leaf Rust pipeline using YOLOv8, SAM2, and SAM3.
12
  ---
13
 
14
+ # ☕ Coffee Leaf Rust (CLR) Severity Estimator
15
+
16
+ This framework processes coffee leaf images and estimates rust severity for each leaf, defined as the percentage of leaf pixels covered by rust lesions, using a 3-step deep learning pipeline:
17
+
18
+ 1. **Leaf Detection**: **YOLOv8** locates and extracts bounding boxes for all coffee leaves in the image.
19
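+ 2. **Instance Segmentation**: **SAM2** (loaded in `app.py` through the `segment_anything` package with the `vit_b` checkpoint `sam_vit_b_01ec64.pth`) takes each bounding box as a prompt and predicts a leaf mask, which is used to create a black-background cutout of each leaf.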
20
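+ 3. **Rust Segmentation**: **SAM3** uses a text prompt ("yellow spot") to find and segment the rust lesions on each leaf cutout. If SAM3 is not installed or raises an error, the app falls back to a simple HSV color threshold for rust-colored pixels.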
21
+
22
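+ The Gradio interface shows the input image annotated with each leaf's bounding box and severity, a gallery containing the cutout and lesion overlay for each individual leaf, and a summary table listing the severity and healthy percentage per leaf.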
23
+
24
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -4,194 +4,246 @@ import gradio as gr
4
  import numpy as np
5
  from PIL import Image
6
  import torch
 
7
  from ultralytics import YOLO
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  ############################################
10
  # Configuration
11
  ############################################
12
 
13
- YOLO_MODEL_PATH = "clr_YOLOV8.pt"
14
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
 
15
 
16
  ############################################
17
- # Load YOLO Model
18
  ############################################
19
 
20
- print("Loading YOLO model...")
 
 
 
 
 
 
21
 
 
 
 
 
 
 
 
 
22
  try:
23
- yolo_model = YOLO(YOLO_MODEL_PATH)
24
- print("YOLO model loaded.")
 
 
 
25
  except Exception as e:
26
- yolo_model = None
27
- print("YOLO loading error:", e)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  ############################################
30
  # Helper Functions
31
  ############################################
32
 
33
- def segment_rust_simple(leaf_img):
34
- """
35
- Simple rust segmentation using HSV color threshold.
36
- Works as fallback when SAM is unavailable.
37
- """
38
-
39
  hsv = cv2.cvtColor(leaf_img, cv2.COLOR_BGR2HSV)
40
-
41
- # Rust-like colors
42
  lower = np.array([10, 80, 80])
43
  upper = np.array([35, 255, 255])
44
-
45
  mask = cv2.inRange(hsv, lower, upper)
46
-
47
  kernel = np.ones((3,3), np.uint8)
48
- mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
- return mask
51
 
 
 
52
 
53
- def calculate_leaf_area(leaf_img):
54
- """
55
- Estimate leaf pixels via threshold.
56
- """
57
- gray = cv2.cvtColor(leaf_img, cv2.COLOR_BGR2GRAY)
58
- _, mask = cv2.threshold(gray, 10, 255, cv2.THRESH_BINARY)
59
 
60
- return mask
 
 
 
61
 
 
 
 
 
 
62
 
63
  ############################################
64
- # Main Processing Function
65
  ############################################
66
 
67
  def process_coffee_leaf(image):
68
-
69
  if image is None:
70
- return None, "Upload an image."
71
 
72
  if yolo_model is None:
73
- return image, "YOLO model not loaded."
74
 
75
  image_np = np.array(image)
76
  image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
 
77
 
78
  results = yolo_model(image_cv, verbose=False)
79
-
80
  boxes = results[0].boxes.xyxy.cpu().numpy()
81
 
82
  if len(boxes) == 0:
83
- return image_np, "No leaves detected."
84
 
85
  annotated = image_np.copy()
86
-
87
- severities = []
88
-
89
- h, w = image_cv.shape[:2]
90
 
91
  for i, box in enumerate(boxes):
92
-
93
  x1, y1, x2, y2 = box.astype(int)
94
 
95
- x1, x2 = max(0, x1), min(w, x2)
96
- y1, y2 = max(0, y1), min(h, y2)
97
-
98
- leaf_crop = image_cv[y1:y2, x1:x2]
99
-
100
  if leaf_crop.size == 0:
101
  continue
102
 
103
- ################################
104
- # Leaf mask
105
- ################################
106
 
107
- leaf_mask = calculate_leaf_area(leaf_crop)
108
  leaf_pixels = cv2.countNonZero(leaf_mask)
109
 
110
- ################################
111
- # Rust segmentation
112
- ################################
113
-
114
- rust_mask = segment_rust_simple(leaf_crop)
115
- rust_pixels = cv2.countNonZero(rust_mask)
116
-
117
- ################################
118
- # Severity calculation
119
- ################################
120
-
121
- severity = 0
122
-
123
- if leaf_pixels > 0:
124
- severity = (rust_pixels / leaf_pixels) * 100
125
 
126
- severities.append(f"Leaf {i+1}: {severity:.2f}%")
 
 
127
 
128
- ################################
129
- # Visualization
130
- ################################
131
 
132
- # draw leaf bbox
133
- cv2.rectangle(annotated,(x1,y1),(x2,y2),(0,255,0),2)
134
 
135
- cv2.putText(
136
- annotated,
137
- f"{severity:.1f}%",
138
- (x1,y1-5),
139
- cv2.FONT_HERSHEY_SIMPLEX,
140
- 0.6,
141
- (0,255,0),
142
- 2
143
- )
144
 
145
- # resize rust mask to image coords
146
- rust_mask = cv2.resize(
147
- rust_mask,
148
- (x2-x1, y2-y1),
149
- interpolation=cv2.INTER_NEAREST
150
- )
151
 
152
- overlay = np.zeros_like(annotated)
 
153
 
154
- overlay[y1:y2, x1:x2][rust_mask > 0] = [255,0,0]
155
 
156
- alpha = 0.4
157
- mask_indices = overlay[:,:,0] > 0
158
 
159
- annotated[mask_indices] = (
160
- annotated[mask_indices]*(1-alpha) +
161
- overlay[mask_indices]*alpha
162
- ).astype(np.uint8)
163
 
164
- report = f"Detected {len(boxes)} leaves\n\n"
165
- report += "\n".join(severities)
166
 
167
- return annotated, report
 
 
168
 
 
 
 
169
 
170
- ############################################
171
- # Gradio Interface
172
- ############################################
 
 
173
 
174
- demo = gr.Interface(
175
- fn=process_coffee_leaf,
176
- inputs=gr.Image(type="pil", label="Upload Coffee Leaf Image"),
177
- outputs=[
178
- gr.Image(label="Analyzed Image"),
179
- gr.Textbox(label="Severity Report")
180
- ],
181
- title="☕ Coffee Leaf Rust Severity Estimator",
182
- description="""
183
- Upload a coffee leaf image.
184
- The system detects leaves using YOLOv8 and estimates rust severity by segmenting rust-colored lesions.
185
- """,
186
- )
187
 
 
 
 
 
188
 
189
  ############################################
190
  # Launch
191
  ############################################
192
 
193
  if __name__ == "__main__":
194
- demo.launch(
195
- server_name="0.0.0.0",
196
- server_port=7860
197
- )
 
4
  import numpy as np
5
  from PIL import Image
6
  import torch
7
+ import urllib.request
8
  from ultralytics import YOLO
9
 
10
+ # Try importing SAM2
11
+ try:
12
+ from segment_anything import sam_model_registry, SamPredictor
13
+ SAM2_AVAILABLE = True
14
+ except ImportError:
15
+ SAM2_AVAILABLE = False
16
+ print("SAM2 not available")
17
+
18
+ # Try importing SAM3
19
+ try:
20
+ from sam3.model_builder import build_sam3_image_model
21
+ from sam3.model.sam3_image_processor import Sam3Processor
22
+ SAM3_AVAILABLE = True
23
+ except ImportError:
24
+ SAM3_AVAILABLE = False
25
+ print("SAM3 not available")
26
+
27
  ############################################
28
  # Configuration
29
  ############################################
30
 
 
31
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
32
+ print("Running on:", DEVICE)
33
+
34
+ YOLO_MODEL_PATH = "clr_YOLOV8.pt"
35
+
36
+ # SAM2 vit_b (lighter)
37
+ SAM2_MODEL_TYPE = "vit_b"
38
+ SAM2_CHECKPOINT_PATH = "sam_vit_b_01ec64.pth"
39
 
40
  ############################################
41
+ # Download SAM2 if needed
42
  ############################################
43
 
44
+ if SAM2_AVAILABLE:
45
+ if not os.path.exists(SAM2_CHECKPOINT_PATH):
46
+ print("Downloading SAM2 vit_b checkpoint...")
47
+ urllib.request.urlretrieve(
48
+ "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
49
+ SAM2_CHECKPOINT_PATH
50
+ )
51
 
52
+ ############################################
53
+ # Load Models
54
+ ############################################
55
+
56
+ print("Loading models...")
57
+
58
+ # YOLO
59
+ yolo_model = None
60
  try:
61
+ if os.path.exists(YOLO_MODEL_PATH):
62
+ yolo_model = YOLO(YOLO_MODEL_PATH)
63
+ print("YOLO loaded")
64
+ else:
65
+ print("YOLO model not found")
66
  except Exception as e:
67
+ print("YOLO error:", e)
68
+
69
+ # SAM2
70
+ sam2_predictor = None
71
+ if SAM2_AVAILABLE:
72
+ try:
73
+ sam2 = sam_model_registry[SAM2_MODEL_TYPE](
74
+ checkpoint=SAM2_CHECKPOINT_PATH
75
+ )
76
+ sam2.to(DEVICE)
77
+ sam2_predictor = SamPredictor(sam2)
78
+ print("SAM2 loaded")
79
+ except Exception as e:
80
+ print("SAM2 error:", e)
81
+
82
+ # SAM3 (official, no checkpoint)
83
+ sam3_processor = None
84
+ if SAM3_AVAILABLE:
85
+ try:
86
+ sam3_model = build_sam3_image_model(device=DEVICE)
87
+ sam3_processor = Sam3Processor(sam3_model)
88
+ print("SAM3 loaded")
89
+ except Exception as e:
90
+ print("SAM3 error:", e)
91
 
92
  ############################################
93
  # Helper Functions
94
  ############################################
95
 
96
+ def fallback_segment_rust(leaf_img):
 
 
 
 
 
97
  hsv = cv2.cvtColor(leaf_img, cv2.COLOR_BGR2HSV)
 
 
98
  lower = np.array([10, 80, 80])
99
  upper = np.array([35, 255, 255])
 
100
  mask = cv2.inRange(hsv, lower, upper)
 
101
  kernel = np.ones((3,3), np.uint8)
102
+ return cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
103
+
104
+ def extract_leaf_sam2(image_rgb, box):
105
+ if not sam2_predictor:
106
+ return np.ones((image_rgb.shape[0], image_rgb.shape[1]), dtype=np.uint8) * 255
107
+
108
+ sam2_predictor.set_image(image_rgb)
109
+ masks, _, _ = sam2_predictor.predict(
110
+ box=np.array(box),
111
+ multimask_output=False
112
+ )
113
+ return (masks[0] * 255).astype(np.uint8)
114
+
115
+ def segment_lesions_sam3(image_rgb):
116
+ if not sam3_processor:
117
+ return fallback_segment_rust(cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR))
118
+
119
+ try:
120
+ pil_img = Image.fromarray(image_rgb)
121
+ state = sam3_processor.set_image(pil_img)
122
+ output = sam3_processor.set_text_prompt(
123
+ state=state,
124
+ prompt="yellow spot"
125
+ )
126
 
127
+ masks = output.get("masks", None)
128
 
129
+ if masks is None or len(masks) == 0:
130
+ return np.zeros((image_rgb.shape[0], image_rgb.shape[1]), dtype=np.uint8)
131
 
132
+ combined = None
133
+ for m in masks:
134
+ m_np = m.detach().cpu().numpy()
135
+ m_np = np.squeeze(m_np)
136
+ m_np = (m_np > 0).astype(np.uint8)
 
137
 
138
+ if combined is None:
139
+ combined = m_np
140
+ else:
141
+ combined = np.maximum(combined, m_np)
142
 
143
+ return combined * 255
144
+
145
+ except Exception as e:
146
+ print("SAM3 error:", e)
147
+ return fallback_segment_rust(cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR))
148
 
149
  ############################################
150
+ # Main Function
151
  ############################################
152
 
153
  def process_coffee_leaf(image):
 
154
  if image is None:
155
+ return None, None, [["Upload image", "-", "-"]]
156
 
157
  if yolo_model is None:
158
+ return image, None, [["Error", "YOLO not loaded", "-"]]
159
 
160
  image_np = np.array(image)
161
  image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
162
+ h, w = image_cv.shape[:2]
163
 
164
  results = yolo_model(image_cv, verbose=False)
 
165
  boxes = results[0].boxes.xyxy.cpu().numpy()
166
 
167
  if len(boxes) == 0:
168
+ return image_np, None, [["No leaves detected", "-", "-"]]
169
 
170
  annotated = image_np.copy()
171
+ gallery = []
172
+ table = []
 
 
173
 
174
  for i, box in enumerate(boxes):
 
175
  x1, y1, x2, y2 = box.astype(int)
176
 
177
+ leaf_crop = image_np[y1:y2, x1:x2]
 
 
 
 
178
  if leaf_crop.size == 0:
179
  continue
180
 
181
+ # SAM2 leaf mask
182
+ leaf_mask_full = extract_leaf_sam2(image_np, box)
183
+ leaf_mask = leaf_mask_full[y1:y2, x1:x2]
184
 
 
185
  leaf_pixels = cv2.countNonZero(leaf_mask)
186
 
187
+ # Cutout
188
+ cutout = np.zeros_like(leaf_crop)
189
+ cutout[leaf_mask > 0] = leaf_crop[leaf_mask > 0]
 
 
 
 
 
 
 
 
 
 
 
 
190
 
191
+ # SAM3 lesions
192
+ rust_mask = segment_lesions_sam3(cutout)
193
+ rust_mask = cv2.bitwise_and(rust_mask, leaf_mask)
194
 
195
+ rust_pixels = cv2.countNonZero(rust_mask)
 
 
196
 
197
+ severity = (rust_pixels / leaf_pixels) * 100 if leaf_pixels > 0 else 0
 
198
 
199
+ # Draw bbox
200
+ cv2.rectangle(annotated, (x1,y1), (x2,y2), (0,255,0), 2)
201
+ cv2.putText(annotated, f"{severity:.1f}%", (x1,y1-5),
202
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,255,0), 2)
 
 
 
 
 
203
 
204
+ # Overlay
205
+ overlay = cutout.copy()
206
+ overlay[rust_mask > 0] = [128, 0, 128]
 
 
 
207
 
208
+ gallery.append((cutout, f"Leaf {i+1}"))
209
+ gallery.append((overlay, f"{severity:.1f}%"))
210
 
211
+ table.append([str(i+1), f"{severity:.1f}%", f"{100-severity:.1f}%"])
212
 
213
+ return annotated, gallery, table
 
214
 
215
+ ############################################
216
+ # UI
217
+ ############################################
 
218
 
219
+ with gr.Blocks() as demo:
220
+ gr.Markdown("# Coffee Leaf Rust Severity Estimator")
221
 
222
+ image_input = gr.Image(type="pil")
223
+ submit = gr.Button("Run")
224
+ clear = gr.Button("Clear")
225
 
226
+ output_img = gr.Image()
227
+ gallery = gr.Gallery(columns=2)
228
+ table = gr.Dataframe(headers=["Leaf", "Severity", "Healthy"])
229
 
230
+ submit.click(
231
+ process_coffee_leaf,
232
+ inputs=image_input,
233
+ outputs=[output_img, gallery, table]
234
+ )
235
 
236
+ def clear_all():
237
+ return None, None, None, None
 
 
 
 
 
 
 
 
 
 
 
238
 
239
+ clear.click(
240
+ clear_all,
241
+ outputs=[image_input, output_img, gallery, table]
242
+ )
243
 
244
  ############################################
245
  # Launch
246
  ############################################
247
 
248
  if __name__ == "__main__":
249
+ demo.launch()