colomboMk committed on
Commit
125c303
·
verified ·
1 Parent(s): 7b14232

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +407 -222
app.py CHANGED
@@ -1,76 +1,143 @@
1
  import os
 
 
 
 
 
 
 
 
2
  import cv2
 
 
3
  import numpy as np
4
  import gradio as gr
5
 
6
  from sahi import AutoDetectionModel
7
  from sahi.predict import get_sliced_prediction
8
 
9
- # Prova a importare ultralytics per il modello di segmentazione nativo (senza SAHI)
10
  try:
11
  from ultralytics import YOLO
12
- _ULTRALYTICS_AVAILABLE = True
13
  except Exception:
14
- _ULTRALYTICS_AVAILABLE = False
15
-
16
- # Soglia massima consentita per il lato della bbox (in pixel) per il modello con SAHI
17
- MAX_SIDE_PX = 70
18
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- def _draw_boxes_rgb(image_rgb: np.ndarray, result, target_class: str):
21
  """
22
- Disegna solo le bbox sul frame RGB (niente etichette testuali).
23
- - Evidenzia in rosso la classe target
24
- - Le altre classi in verde
25
- - Scarta le bbox con lato (max tra width e height) > MAX_SIDE_PX
26
- Restituisce (immagine_annotata_RGB, counts_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  """
28
- # Garantisci 3 canali
29
- if image_rgb.ndim == 2:
30
- image_rgb = cv2.cvtColor(image_rgb, cv2.COLOR_GRAY2RGB)
31
- elif image_rgb.shape[2] == 4:
32
- image_rgb = cv2.cvtColor(image_rgb, cv2.COLOR_RGBA2RGB)
33
-
34
  H, W = image_rgb.shape[:2]
 
 
35
 
36
- # OpenCV disegna in BGR
37
- vis_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
38
  target_count = 0
39
  total_count = 0
40
-
41
- object_predictions = getattr(result, "object_prediction_list", []) or []
42
 
43
  for item in object_predictions:
44
- # bbox
45
  try:
46
  x1, y1, x2, y2 = map(int, item.bbox.to_xyxy())
47
  except Exception:
48
  x1, y1 = int(getattr(item.bbox, "minx", 0)), int(getattr(item.bbox, "miny", 0))
49
  x2, y2 = int(getattr(item.bbox, "maxx", 0)), int(getattr(item.bbox, "maxy", 0))
50
 
51
- # Clamp ai bordi immagine
52
- x1 = max(0, min(x1, W - 1))
53
- y1 = max(0, min(y1, H - 1))
54
- x2 = max(0, min(x2, W - 1))
55
- y2 = max(0, min(y2, H - 1))
56
-
57
- # Normalizza coordinate in caso invertite
58
- if x2 < x1:
59
- x1, x2 = x2, x1
60
- if y2 < y1:
61
- y1, y2 = y2, y1
62
 
63
- # Scarta bbox non valide
64
  w = max(0, x2 - x1)
65
  h = max(0, y2 - y1)
66
  if w == 0 or h == 0:
67
  continue
68
-
69
- # Scarta le bbox con lato maggiore della soglia
70
  if max(w, h) > MAX_SIDE_PX:
71
  continue
72
 
73
- # Scarta bbox con area non positiva (per sicurezza)
74
  area = getattr(item.bbox, "area", w * h)
75
  try:
76
  area_val = float(area() if callable(area) else area)
@@ -80,60 +147,58 @@ def _draw_boxes_rgb(image_rgb: np.ndarray, result, target_class: str):
80
  continue
81
 
82
  cls = getattr(item.category, "name", "unknown")
83
- is_target = (cls == target_class)
 
 
 
84
 
85
- color_bgr = (0, 0, 255) if is_target else (0, 200, 0) # rosso per target, verde per altre
86
- cv2.rectangle(vis_bgr, (x1, y1), (x2, y2), color_bgr, 2)
87
- # Nessuna label testuale
88
 
89
  total_count += 1
90
  if is_target:
91
  target_count += 1
92
 
93
- vis_rgb = cv2.cvtColor(vis_bgr, cv2.COLOR_BGR2RGB)
94
- counts_text = f"target='{target_class}': {target_count} | totale: {total_count}"
95
- return vis_rgb, counts_text
96
-
 
 
 
97
 
98
- def _draw_segmentation_masks_rgb(image_rgb: np.ndarray, ulty_result, target_class: str, alpha: float = 0.45):
99
  """
100
- Disegna le maschere di segmentazione (niente etichette testuali).
101
- - Evidenzia in rosso la classe target
102
- - Le altre classi in verde
103
- - Restituisce (immagine_annotata_RGB, counts_text)
104
  """
105
- if image_rgb.ndim == 2:
106
- image_rgb = cv2.cvtColor(image_rgb, cv2.COLOR_GRAY2RGB)
107
- elif image_rgb.shape[2] == 4:
108
- image_rgb = cv2.cvtColor(image_rgb, cv2.COLOR_RGBA2RGB)
109
-
110
- vis_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
111
 
112
- # Estrarre info dal risultato Ultralytics
113
- r = ulty_result
114
  names = getattr(r, "names", None)
115
  boxes = getattr(r, "boxes", None)
116
  masks = getattr(r, "masks", None)
117
 
118
  if boxes is None or len(boxes) == 0:
119
- # Nessun oggetto
120
- return cv2.cvtColor(vis_bgr, cv2.COLOR_BGR2RGB), f"target='{target_class}': 0 | totale: 0"
121
 
122
- # Numero di istanze
123
  N = len(boxes)
124
-
125
- # Prepara maschere (se presenti)
126
  mask_data = None
127
  if masks is not None and getattr(masks, "data", None) is not None:
128
  try:
129
- mask_data = masks.data # torch.Tensor [N, H, W]
130
  except Exception:
131
  mask_data = None
132
 
133
  target_count = 0
134
  total_count = 0
 
135
 
136
- # Loop istanze
137
  for i in range(N):
138
  try:
139
  cls_idx = int(boxes.cls[i].item())
@@ -143,42 +208,44 @@ def _draw_segmentation_masks_rgb(image_rgb: np.ndarray, ulty_result, target_clas
143
  if isinstance(names, dict):
144
  cls_name = names.get(cls_idx, cls_name)
145
 
146
- is_target = (cls_name == target_class)
147
-
148
- color_bgr = (0, 0, 255) if is_target else (0, 200, 0) # rosso per target, verde per altre
149
 
150
- # Disegna mask se disponibile
151
  if mask_data is not None and i < len(mask_data):
152
  try:
153
  m = mask_data[i]
154
  m = m.detach().cpu().numpy()
155
- m = (m > 0.5).astype(np.uint8) # binarizza
156
- # Assicurare dimensioni identiche a immagine
157
- if m.shape[:2] != vis_bgr.shape[:2]:
158
- m = cv2.resize(m, (vis_bgr.shape[1], vis_bgr.shape[0]), interpolation=cv2.INTER_NEAREST)
159
 
160
- # Overlay colore
161
- overlay = np.zeros_like(vis_bgr, dtype=np.uint8)
162
- overlay[m.astype(bool)] = color_bgr
163
- vis_bgr = cv2.addWeighted(overlay, alpha, vis_bgr, 1 - alpha, 0)
164
 
165
- # Contorno
166
  cnts, _ = cv2.findContours(m, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
167
- cv2.drawContours(vis_bgr, cnts, -1, color_bgr, 2)
 
 
168
  except Exception:
169
- # fallback: disegna il bbox
170
  try:
171
  xyxy = boxes.xyxy[i].detach().cpu().numpy().astype(int)
172
  x1, y1, x2, y2 = map(int, xyxy)
173
- cv2.rectangle(vis_bgr, (x1, y1), (x2, y2), color_bgr, 2)
 
174
  except Exception:
175
  pass
176
  else:
177
- # Nessuna mask: disegna solo bbox
178
  try:
179
  xyxy = boxes.xyxy[i].detach().cpu().numpy().astype(int)
180
  x1, y1, x2, y2 = map(int, xyxy)
181
- cv2.rectangle(vis_bgr, (x1, y1), (x2, y2), color_bgr, 2)
 
182
  except Exception:
183
  pass
184
 
@@ -186,75 +253,42 @@ def _draw_segmentation_masks_rgb(image_rgb: np.ndarray, ulty_result, target_clas
186
  if is_target:
187
  target_count += 1
188
 
189
- vis_rgb = cv2.cvtColor(vis_bgr, cv2.COLOR_BGR2RGB)
190
- counts_text = f"target='{target_class}': {target_count} | totale: {total_count}"
191
- return vis_rgb, counts_text
192
-
193
-
194
- def infer_two_models(
195
- image: np.ndarray,
196
- weights_det_path: str,
197
- conf_det: float,
198
- slice_h: int,
199
- slice_w: int,
200
- overlap_h: float,
201
- overlap_w: float,
202
- device: str,
203
- target_class: str,
204
- weights_seg_path: str,
205
- conf_seg: float,
206
- ):
207
  """
208
- Esegue inferenza su una singola immagine con due modelli:
209
- - Modello A (Detection via SAHI): usa pesi YOLOv11 segment come detection, disegna solo bbox, filtra box con lato > MAX_SIDE_PX
210
- - Modello B (Segmentation nativo YOLO): nessun SAHI, disegna solo maschere (niente etichette)
211
- Restituisce 4 output: (img_det, counts_det, img_seg, counts_seg)
 
212
  """
213
- if image is None:
214
- raise gr.Error("Devi caricare un'immagine.")
215
-
216
- if not weights_det_path or not os.path.exists(weights_det_path):
217
- raise gr.Error(f"File pesi (Detection/SAHI) non trovato: {weights_det_path}")
218
-
219
- if not weights_seg_path or not os.path.exists(weights_seg_path):
220
- raise gr.Error(f"File pesi (Segmentation) non trovato: {weights_seg_path}")
221
-
222
- if not _ULTRALYTICS_AVAILABLE:
223
- raise gr.Error("Ultralytics non è installato per il modello di segmentazione. Installa con: pip install ultralytics")
224
-
225
- image_rgb = image.copy()
226
- model_type = "yolov11"
227
-
228
- # Scelta automatica device se 'auto'
229
- chosen_device = device
230
- if device == "auto":
231
- try:
232
- import torch
233
- chosen_device = "cuda:0" if torch.cuda.is_available() else "cpu"
234
- except Exception:
235
- chosen_device = "cpu"
236
-
237
- # =========================
238
- # Modello A: Detection con SAHI (boxes only)
239
- # =========================
240
- try:
241
- detection_model = AutoDetectionModel.from_pretrained(
242
- model_type=model_type,
243
- model_path=weights_det_path,
244
- confidence_threshold=conf_det,
245
- device=chosen_device,
246
- )
247
- except Exception:
248
- detection_model = AutoDetectionModel.from_pretrained(
249
- model_type=model_type,
250
- model_path=weights_det_path,
251
- confidence_threshold=conf_det,
252
- device="cpu",
253
- )
254
-
255
- sahi_result = get_sliced_prediction(
256
  image_rgb,
257
- detection_model,
258
  slice_height=int(slice_h),
259
  slice_width=int(slice_w),
260
  overlap_height_ratio=float(overlap_h),
@@ -263,106 +297,257 @@ def infer_two_models(
263
  verbose=0,
264
  )
265
 
266
- det_vis_rgb, det_counts_text = _draw_boxes_rgb(image_rgb, sahi_result, target_class)
267
 
268
- # =========================
269
- # Modello B: YOLO Segmentation nativo (no SAHI)
270
- # =========================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
  try:
272
- seg_model = YOLO(weights_seg_path)
273
- # Nota: Ultralytics gestisce internamente il device; possiamo passarlo qui
274
- # Se chosen_device è 'cpu' o 'cuda:0'
275
- # Alcune versioni usano 'device' in predict(), altre in load/attr; .predict supporta device
276
- seg_results = seg_model.predict(
277
- source=image_rgb,
278
- conf=conf_seg,
279
- device=chosen_device,
280
- verbose=False,
281
- )
282
- # Prendi il primo risultato
283
  r0 = seg_results[0] if isinstance(seg_results, (list, tuple)) else seg_results
284
  except Exception as e:
285
- raise gr.Error(f"Errore durante l'inferenza del modello di segmentazione: {e}")
286
 
287
- seg_vis_rgb, seg_counts_text = _draw_segmentation_masks_rgb(image_rgb, r0, target_class)
 
 
288
 
289
- return det_vis_rgb, det_counts_text, seg_vis_rgb, seg_counts_text
 
 
290
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
 
292
  def build_app():
293
- with gr.Blocks(title="Berries counting and bunches segmentation - Owl-Nest") as demo:
294
  gr.Markdown(
295
- "- Carica un'immagine e lancia l'inferenza con due modelli YOLO.\n"
296
- "- Modello A dedicato al rilevamento e conteggio di acini.\n"
297
- "- Modello B dedicato alla segmentazione di grappoli."
 
 
298
  )
299
 
 
 
300
  with gr.Row():
301
- with gr.Column():
302
  img_in = gr.Image(label="Immagine", type="numpy")
303
-
304
- gr.Markdown("### Pesi modelli")
305
- weights_det = gr.Textbox(
306
- label="Percorso pesi Modello A",
307
- value="weights/berry.pt",
308
- placeholder="es. weights/best.pt",
309
- )
310
- weights_seg = gr.Textbox(
311
- label="Percorso pesi Modello B",
312
- value="weights/bunch.pt",
313
- placeholder="es. weights/seg.pt",
314
- )
315
-
316
- target = gr.Textbox(label="Classe target", value="berry")
317
-
318
- gr.Markdown("### Parametri modello A")
319
- with gr.Row():
320
- conf_det = gr.Slider(0.0, 1.0, value=0.35, step=0.01, label="Confidence (A)")
321
- device = gr.Dropdown(
322
- ["auto", "cuda:0", "cpu"],
323
- value="auto",
324
- label="Device",
325
  )
326
 
327
  with gr.Row():
328
- slice_h = gr.Slider(64, 2048, value=640, step=32, label="Slice H (A)")
329
- slice_w = gr.Slider(64, 2048, value=640, step=32, label="Slice W (A)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
330
 
331
  with gr.Row():
332
- overlap_h = gr.Slider(0.0, 0.9, value=0.10, step=0.01, label="Overlap H ratio (A)")
333
- overlap_w = gr.Slider(0.0, 0.9, value=0.10, step=0.01, label="Overlap W ratio (A)")
334
-
335
- gr.Markdown("### Parametri modello B")
336
- conf_seg = gr.Slider(0.0, 1.0, value=0.35, step=0.01, label="Confidence (B)")
337
 
338
- run_btn = gr.Button("Esegui inferenza", variant="primary")
 
 
 
339
 
340
- with gr.Column():
341
- gr.Markdown("### Risultato Modello A")
342
- img_out_det = gr.Image(label="Detections (solo bbox)", type="numpy")
343
- counts_out_det = gr.Textbox(label="Conteggi (A)", interactive=False)
 
 
 
 
 
 
 
 
344
 
345
- gr.Markdown("### Risultato Modello B")
346
- img_out_seg = gr.Image(label="Segmentazione (maschere)", type="numpy")
347
- counts_out_seg = gr.Textbox(label="Conteggi (B)", interactive=False)
 
 
 
 
 
 
348
 
349
- run_btn.click(
350
- infer_two_models,
351
  inputs=[
352
- img_in,
353
- weights_det, conf_det,
354
- slice_h, slice_w, overlap_h, overlap_w,
355
- device,
356
- target,
357
- weights_seg, conf_seg
358
  ],
359
- outputs=[img_out_det, counts_out_det, img_out_seg, counts_out_seg],
360
  )
361
 
362
- return demo
 
 
 
 
363
 
 
 
 
 
 
 
 
 
 
 
 
 
 
364
 
365
  if __name__ == "__main__":
366
  demo = build_app()
367
- # Su Spaces non è necessario specificare server_name o share
368
  demo.launch()
 
import os

# Route all framework caches to /tmp so downloads don't fill the Space's
# persistent storage.
os.environ.setdefault("HF_HOME", "/tmp/hf_home")
os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/hf_home/transformers")
os.environ.setdefault("HF_HUB_CACHE", "/tmp/hf_home/hub")
os.environ.setdefault("TORCH_HOME", "/tmp/torch_home")
os.environ.setdefault("PIP_DISABLE_PIP_VERSION_CHECK", "1")

import cv2
import time
import shutil
import numpy as np
import gradio as gr

from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction

# Native segmentation requires ultralytics; degrade gracefully when absent.
try:
    from ultralytics import YOLO
    _ULTRA_OK = True
except Exception:
    _ULTRA_OK = False

# Configuration
MAX_SIDE_PX = 70          # max bbox side (px) accepted by model A's (SAHI) filter
SEG_DEFAULT_ALPHA = 0.45  # default fill opacity for segmentation masks

# Module-level caches so repeated clicks don't reload model weights.
_DET_MODEL_CACHE = {}  # key: (weights_path, device) -> AutoDetectionModel
_SEG_MODEL_CACHE = {}  # key: weights_path -> YOLO
34
+ def _ensure_rgb(img: np.ndarray) -> np.ndarray:
35
+ if img is None:
36
+ return None
37
+ if img.ndim == 2:
38
+ return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
39
+ if img.shape[2] == 4:
40
+ return cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
41
+ return img
42
+
43
+ def _choose_device(user_choice: str) -> str:
44
+ if user_choice != "auto":
45
+ return user_choice
46
+ try:
47
+ import torch
48
+ return "cuda:0" if torch.cuda.is_available() else "cpu"
49
+ except Exception:
50
+ return "cpu"
51
 
52
def _get_det_model(weights_path: str, device: str, conf: float):
    """Return a cached SAHI ``AutoDetectionModel`` for the given weights/device.

    The model is loaded once per (weights, device) pair and reused across
    clicks; on a cache hit only the confidence threshold is refreshed
    (best effort). Loading first targets the requested device; if that
    fails (e.g. no CUDA), a CPU load is attempted and any error from the
    second attempt propagates.

    Raises:
        gr.Error: if the weights file does not exist.
    """
    if not os.path.exists(weights_path):
        raise gr.Error(f"Pesi detection non trovati: {weights_path}")

    cache_key = (weights_path, device)
    cached = _DET_MODEL_CACHE.get(cache_key)
    if cached is not None:
        # Refresh the threshold on the cached instance if the attribute exists.
        try:
            cached.confidence_threshold = float(conf)
        except Exception:
            pass
        return cached

    # SAHI's "yolov8" wrapper also loads Ultralytics v9/v11 checkpoints.
    load_kwargs = dict(
        model_type="yolov8",
        model_path=weights_path,
        confidence_threshold=conf,
    )
    try:
        model = AutoDetectionModel.from_pretrained(device=device, **load_kwargs)
    except Exception:
        # CPU fallback; a failure here is allowed to propagate.
        model = AutoDetectionModel.from_pretrained(device="cpu", **load_kwargs)
    _DET_MODEL_CACHE[cache_key] = model
    return model
85
+
86
def _get_seg_model(weights_path: str):
    """Return a cached Ultralytics YOLO segmentation model.

    The checkpoint is loaded on first use and memoized by path.

    Raises:
        gr.Error: if ultralytics is unavailable or the weights are missing.
    """
    if not _ULTRA_OK:
        raise gr.Error("Ultralytics non installato. Installa con: pip install ultralytics")
    if not os.path.exists(weights_path):
        raise gr.Error(f"Pesi segmentation non trovati: {weights_path}")
    if weights_path not in _SEG_MODEL_CACHE:
        _SEG_MODEL_CACHE[weights_path] = YOLO(weights_path)
    return _SEG_MODEL_CACHE[weights_path]
96
+
97
+ def _optimize_slicing_dims(H: int, W: int, slice_h: int, slice_w: int, overlap_h: float, overlap_w: float, auto_opt: bool):
98
+ if not auto_opt:
99
+ return int(slice_h), int(slice_w), float(overlap_h), float(overlap_w)
100
+ sh = min(int(slice_h), H)
101
+ sw = min(int(slice_w), W)
102
+ # If the image already fits in one slice, remove overlap to reduce work
103
+ oh = 0.0 if (H <= sh and W <= sw) else float(overlap_h)
104
+ ow = 0.0 if (H <= sh and W <= sw) else float(overlap_w)
105
+ return sh, sw, oh, ow
106
+
107
+ def _draw_boxes_overlay(image_rgb: np.ndarray, sahi_result, target_class: str, use_target: bool):
108
+ """
109
+ Returns overlay_rgb (H,W,3), alpha_mask (H,W) uint8, counts_text
110
+ Only draws rectangles (no labels). Filters boxes with max side > MAX_SIDE_PX.
111
  """
 
 
 
 
 
 
112
  H, W = image_rgb.shape[:2]
113
+ overlay = np.zeros((H, W, 3), dtype=np.uint8)
114
+ alpha = np.zeros((H, W), dtype=np.uint8)
115
 
 
 
116
  target_count = 0
117
  total_count = 0
118
+ object_predictions = getattr(sahi_result, "object_prediction_list", []) or []
 
119
 
120
  for item in object_predictions:
121
+ # parse bbox
122
  try:
123
  x1, y1, x2, y2 = map(int, item.bbox.to_xyxy())
124
  except Exception:
125
  x1, y1 = int(getattr(item.bbox, "minx", 0)), int(getattr(item.bbox, "miny", 0))
126
  x2, y2 = int(getattr(item.bbox, "maxx", 0)), int(getattr(item.bbox, "maxy", 0))
127
 
128
+ # clamp and normalize
129
+ x1 = max(0, min(x1, W - 1)); x2 = max(0, min(x2, W - 1))
130
+ y1 = max(0, min(y1, H - 1)); y2 = max(0, min(y2, H - 1))
131
+ if x2 < x1: x1, x2 = x2, x1
132
+ if y2 < y1: y1, y2 = y2, y1
 
 
 
 
 
 
133
 
 
134
  w = max(0, x2 - x1)
135
  h = max(0, y2 - y1)
136
  if w == 0 or h == 0:
137
  continue
 
 
138
  if max(w, h) > MAX_SIDE_PX:
139
  continue
140
 
 
141
  area = getattr(item.bbox, "area", w * h)
142
  try:
143
  area_val = float(area() if callable(area) else area)
 
147
  continue
148
 
149
  cls = getattr(item.category, "name", "unknown")
150
+ is_target = (cls == target_class) if use_target else False
151
+
152
+ # Colors in BGR for OpenCV, convert later when compositing
153
+ color_bgr = (0, 0, 255) if is_target and use_target else (0, 200, 0)
154
 
155
+ # Draw on overlay (BGR)
156
+ cv2.rectangle(overlay, (x1, y1), (x2, y2), color_bgr, 2)
157
+ cv2.rectangle(alpha, (x1, y1), (x2, y2), 255, 2)
158
 
159
  total_count += 1
160
  if is_target:
161
  target_count += 1
162
 
163
+ # Convert overlay BGR -> RGB
164
+ overlay_rgb = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
165
+ if use_target:
166
+ counts = f"target='{target_class}': {target_count} | totale: {total_count}"
167
+ else:
168
+ counts = f"totale: {total_count}"
169
+ return overlay_rgb, alpha, counts
170
 
171
+ def _draw_seg_overlay(image_rgb: np.ndarray, yolo_result, target_class: str, use_target: bool, fill_alpha: float = SEG_DEFAULT_ALPHA):
172
  """
173
+ Returns overlay_rgb (H,W,3), alpha_mask (H,W) uint8, counts_text for segmentation
174
+ - Fills masks with color (red for target, green for others if target enabled; else green)
175
+ - Draws contour opaque
 
176
  """
177
+ H, W = image_rgb.shape[:2]
178
+ overlay_bgr = np.zeros((H, W, 3), dtype=np.uint8)
179
+ alpha = np.zeros((H, W), dtype=np.uint8)
 
 
 
180
 
181
+ r = yolo_result
 
182
  names = getattr(r, "names", None)
183
  boxes = getattr(r, "boxes", None)
184
  masks = getattr(r, "masks", None)
185
 
186
  if boxes is None or len(boxes) == 0:
187
+ counts = f"target='{target_class}': 0 | totale: 0" if use_target else "totale: 0"
188
+ return cv2.cvtColor(overlay_bgr, cv2.COLOR_BGR2RGB), alpha, counts
189
 
 
190
  N = len(boxes)
 
 
191
  mask_data = None
192
  if masks is not None and getattr(masks, "data", None) is not None:
193
  try:
194
+ mask_data = masks.data # torch.Tensor [N, H, W] (prob/mask)
195
  except Exception:
196
  mask_data = None
197
 
198
  target_count = 0
199
  total_count = 0
200
+ fa255 = int(max(0.0, min(1.0, float(fill_alpha))) * 255)
201
 
 
202
  for i in range(N):
203
  try:
204
  cls_idx = int(boxes.cls[i].item())
 
208
  if isinstance(names, dict):
209
  cls_name = names.get(cls_idx, cls_name)
210
 
211
+ is_target = (cls_name == target_class) if use_target else False
212
+ color_bgr = (0, 0, 255) if is_target and use_target else (0, 200, 0)
 
213
 
 
214
  if mask_data is not None and i < len(mask_data):
215
  try:
216
  m = mask_data[i]
217
  m = m.detach().cpu().numpy()
218
+ m = (m > 0.5).astype(np.uint8) # binary mask
219
+
220
+ if m.shape[:2] != (H, W):
221
+ m = cv2.resize(m, (W, H), interpolation=cv2.INTER_NEAREST)
222
 
223
+ # Fill color where mask is 1
224
+ overlay_bgr[m == 1] = color_bgr
225
+ # Alpha for fill
226
+ alpha[m == 1] = np.maximum(alpha[m == 1], fa255)
227
 
228
+ # Contours opaque
229
  cnts, _ = cv2.findContours(m, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
230
+ cv2.drawContours(overlay_bgr, cnts, -1, color_bgr, 2)
231
+ # Draw contour alpha to 255
232
+ cv2.drawContours(alpha, cnts, -1, 255, 2)
233
  except Exception:
234
+ # fallback to bbox
235
  try:
236
  xyxy = boxes.xyxy[i].detach().cpu().numpy().astype(int)
237
  x1, y1, x2, y2 = map(int, xyxy)
238
+ cv2.rectangle(overlay_bgr, (x1, y1), (x2, y2), color_bgr, 2)
239
+ cv2.rectangle(alpha, (x1, y1), (x2, y2), 255, 2)
240
  except Exception:
241
  pass
242
  else:
243
+ # No mask: draw bbox
244
  try:
245
  xyxy = boxes.xyxy[i].detach().cpu().numpy().astype(int)
246
  x1, y1, x2, y2 = map(int, xyxy)
247
+ cv2.rectangle(overlay_bgr, (x1, y1), (x2, y2), color_bgr, 2)
248
+ cv2.rectangle(alpha, (x1, y1), (x2, y2), 255, 2)
249
  except Exception:
250
  pass
251
 
 
253
  if is_target:
254
  target_count += 1
255
 
256
+ overlay_rgb = cv2.cvtColor(overlay_bgr, cv2.COLOR_BGR2RGB)
257
+ if use_target:
258
+ counts = f"target='{target_class}': {target_count} | totale: {total_count}"
259
+ else:
260
+ counts = f"totale: {total_count}"
261
+ return overlay_rgb, alpha, counts
262
+
263
+ def _composite_layers(base_rgb: np.ndarray, layers: list):
 
 
 
 
 
 
 
 
 
 
264
  """
265
+ layers: list of dicts with keys:
266
+ - 'overlay' : np.ndarray HxWx3 RGB
267
+ - 'alpha' : np.ndarray HxW uint8
268
+ - 'ts' : float (timestamp), to control stacking order (oldest first)
269
+ Newest layer should be on top: sort by ts ascending and apply in order.
270
  """
271
+ if base_rgb is None:
272
+ return None
273
+ result = base_rgb.astype(np.float32)
274
+
275
+ # sort by timestamp (oldest first)
276
+ layers_sorted = sorted([l for l in layers if l is not None], key=lambda d: d["ts"])
277
+ for layer in layers_sorted:
278
+ ov = layer["overlay"].astype(np.float32)
279
+ a = (layer["alpha"].astype(np.float32) / 255.0)[..., None] # HxWx1
280
+ if ov.shape[:2] != result.shape[:2]:
281
+ ov = cv2.resize(ov, (result.shape[1], result.shape[0]), interpolation=cv2.INTER_LINEAR)
282
+ a = cv2.resize(a, (result.shape[1], result.shape[0]), interpolation=cv2.INTER_LINEAR)[..., None]
283
+ # alpha blend only where a > 0
284
+ result = ov * a + result * (1.0 - a)
285
+
286
+ return np.clip(result, 0, 255).astype(np.uint8)
287
+
288
+ def _sahi_predict(image_rgb: np.ndarray, det_model, slice_h, slice_w, overlap_h, overlap_w):
289
+ return get_sliced_prediction(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
  image_rgb,
291
+ det_model,
292
  slice_height=int(slice_h),
293
  slice_width=int(slice_w),
294
  overlap_height_ratio=float(overlap_h),
 
297
  verbose=0,
298
  )
299
 
300
+ # Gradio callables
301
 
302
def on_image_upload(image, state):
    """Reset both overlays whenever a new image is loaded.

    Returns ``(preview_image, fresh_state, det_counts, seg_counts)``;
    the counts textboxes are cleared.
    """
    fresh = {"base": None, "det": None, "seg": None, "det_counts": "", "seg_counts": ""}
    if image is None:
        return None, fresh, "", ""
    rgb = _ensure_rgb(image)
    fresh["base"] = rgb
    return rgb, fresh, "", ""
311
+
312
def run_det(
    image, state,
    weights_det_path, conf_det, slice_h, slice_w, overlap_h, overlap_w, device,
    target_class, use_target, auto_opt_slice
):
    """Run model A (SAHI sliced detection) and refresh only the 'det' layer.

    The final image is recomposed from both layers (det + seg) in
    chronological order, so the layer produced last sits on top.

    Raises:
        gr.Error: if no base image has been uploaded yet.
    """
    if state is None or state.get("base") is None:
        raise gr.Error("Carica prima un'immagine.")

    base_img = state["base"]
    img_h, img_w = base_img.shape[:2]

    model = _get_det_model(weights_det_path, _choose_device(device), conf_det)
    sl_h, sl_w, ov_h, ov_w = _optimize_slicing_dims(
        img_h, img_w, slice_h, slice_w, overlap_h, overlap_w, auto_opt_slice
    )
    prediction = _sahi_predict(base_img, model, sl_h, sl_w, ov_h, ov_w)

    layer_rgb, layer_alpha, counts = _draw_boxes_overlay(
        base_img, prediction, target_class, bool(use_target)
    )
    state["det"] = {"overlay": layer_rgb, "alpha": layer_alpha, "ts": time.time()}
    state["det_counts"] = counts

    composite = _composite_layers(base_img, [state["det"], state.get("seg")])
    return composite, state, state["det_counts"], state.get("seg_counts", "")
337
+
338
def run_seg(
    image, state,
    weights_seg_path, conf_seg, device,
    target_class, use_target, seg_alpha
):
    """Run model B (native YOLO segmentation) and refresh only the 'seg' layer.

    The final image is recomposed from both layers (det + seg) in
    chronological order, so the layer produced last sits on top.

    Raises:
        gr.Error: if no base image has been uploaded yet, or if the
            segmentation inference fails.
    """
    if state is None or state.get("base") is None:
        raise gr.Error("Carica prima un'immagine.")

    base_img = state["base"]
    model = _get_seg_model(weights_seg_path)
    try:
        # Ultralytics accepts the device directly in predict().
        raw = model.predict(
            source=base_img,
            conf=float(conf_seg),
            device=_choose_device(device),
            verbose=False,
        )
        first = raw[0] if isinstance(raw, (list, tuple)) else raw
    except Exception as e:
        raise gr.Error(f"Errore inferenza segmentation: {e}")

    layer_rgb, layer_alpha, counts = _draw_seg_overlay(
        base_img, first, target_class, bool(use_target), float(seg_alpha)
    )
    state["seg"] = {"overlay": layer_rgb, "alpha": layer_alpha, "ts": time.time()}
    state["seg_counts"] = counts

    composite = _composite_layers(base_img, [state.get("det"), state["seg"]])
    return composite, state, state.get("det_counts", ""), state["seg_counts"]
365
 
366
def clear_overlays(image, state):
    """Drop both overlays and show the bare base image again.

    Counts textboxes are cleared. With no base image loaded, a fresh
    empty state is returned instead.
    """
    if state is None or state.get("base") is None:
        fresh = {"base": None, "det": None, "seg": None, "det_counts": "", "seg_counts": ""}
        return None, fresh, "", ""
    state.update(det=None, seg=None, det_counts="", seg_counts="")
    return state["base"], state, "", ""
375
+
376
+ # Maintenance helpers
377
+
378
+ def _dir_size(path: str) -> int:
379
+ try:
380
+ total = 0
381
+ for root, _, files in os.walk(path):
382
+ for f in files:
383
+ fp = os.path.join(root, f)
384
+ try:
385
+ total += os.path.getsize(fp)
386
+ except Exception:
387
+ pass
388
+ return total
389
+ except Exception:
390
+ return 0
391
+
392
+ def _fmt_bytes(n: int) -> str:
393
+ for unit in ["B", "KB", "MB", "GB", "TB"]:
394
+ if n < 1024.0:
395
+ return f"{n:.1f} {unit}"
396
+ n /= 1024.0
397
+ return f"{n:.1f} PB"
398
+
399
def check_storage():
    """Report sizes of the known cache directories plus overall disk usage.

    Returns a multi-line string suitable for the maintenance textbox.
    """
    cache_dirs = [
        os.path.expanduser("~/.cache/huggingface/hub"),
        os.path.expanduser("~/.cache/torch"),
        os.path.expanduser("~/.cache/pip"),
        os.path.expanduser("~/.config/Ultralytics"),
        "/tmp/hf_home/hub",
        "/tmp/torch_home",
    ]
    report = [
        f"{d}: {_fmt_bytes(_dir_size(d) if os.path.exists(d) else 0)}"
        for d in cache_dirs
    ]
    try:
        total, used, free = shutil.disk_usage("/")
        disk_line = (
            f"Disk usage: used {_fmt_bytes(used)} / total {_fmt_bytes(total)} "
            f"(free {_fmt_bytes(free)})"
        )
    except Exception:
        disk_line = "Disk usage: n/a"
    return "Cache sizes:\n" + "\n".join(report) + "\n" + disk_line
421
+
422
def clean_caches():
    """Delete the known cache directories to reclaim disk space.

    Deletion is best effort: errors are ignored and never interrupt the
    sweep. Returns a log listing the paths that existed and were removed.
    """
    candidates = [
        os.path.expanduser("~/.cache/huggingface/hub"),
        os.path.expanduser("~/.cache/torch"),
        os.path.expanduser("~/.cache/pip"),
        os.path.expanduser("~/.config/Ultralytics"),
        "/tmp/hf_home",
        "/tmp/torch_home",
    ]
    removed = []
    for path in candidates:
        try:
            if not os.path.exists(path):
                continue
            shutil.rmtree(path, ignore_errors=True)
            removed.append(path)
        except Exception:
            continue
    return "Removed:\n" + ("\n".join(removed) if removed else "(none)")
440
 
441
  def build_app():
442
+ with gr.Blocks(title="YOLOv11 SAHI Detection + YOLO Segmentation (dual overlays)") as demo:
443
  gr.Markdown(
444
+ "## Doppia inferenza su stessa immagine, overlay combinati\n"
445
+ "- Modello A: SAHI detection (usa pesi YOLOv11 seg come detection) — solo bbox, filtro lato > 70px.\n"
446
+ "- Modello B: YOLO segmentation nativo maschere riempite + contorno.\n"
447
+ "- Esegui i modelli con pulsanti separati; gli overlay si accumulano sull'immagine base (nuovo overlay sopra).\n"
448
+ "- Opzionale: disabilita l'evidenziazione della classe target se non ti serve."
449
  )
450
 
451
+ state = gr.State({"base": None, "det": None, "seg": None, "det_counts": "", "seg_counts": ""})
452
+
453
  with gr.Row():
454
+ with gr.Column(scale=1):
455
  img_in = gr.Image(label="Immagine", type="numpy")
456
+ with gr.Accordion("Pesi modelli", open=True):
457
+ weights_det = gr.Textbox(
458
+ label="Pesi Modello A (Detection + SAHI, .pt)",
459
+ value="weights/best.pt",
460
+ )
461
+ weights_seg = gr.Textbox(
462
+ label="Pesi Modello B (Segmentation, .pt)",
463
+ value="weights/seg.pt",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
464
  )
465
 
466
  with gr.Row():
467
+ target = gr.Textbox(label="Classe target", value="berry")
468
+ use_target = gr.Checkbox(label="Usa classe target", value=True)
469
+
470
+ with gr.Tab("Modello A — SAHI Detection"):
471
+ with gr.Row():
472
+ conf_det = gr.Slider(0.0, 1.0, value=0.35, step=0.01, label="Confidence (A)")
473
+ device_a = gr.Dropdown(["auto", "cuda:0", "cpu"], value="auto", label="Device")
474
+ with gr.Row():
475
+ slice_h = gr.Slider(64, 2048, value=640, step=32, label="Slice H (A)")
476
+ slice_w = gr.Slider(64, 2048, value=640, step=32, label="Slice W (A)")
477
+ with gr.Row():
478
+ overlap_h = gr.Slider(0.0, 0.9, value=0.10, step=0.01, label="Overlap H (A)")
479
+ overlap_w = gr.Slider(0.0, 0.9, value=0.10, step=0.01, label="Overlap W (A)")
480
+ auto_opt_slice = gr.Checkbox(label="Ottimizza slicing su immagini piccole", value=True)
481
+ btn_det = gr.Button("Esegui Modello A (SAHI)")
482
+
483
+ with gr.Tab("Modello B — YOLO Segmentation"):
484
+ with gr.Row():
485
+ conf_seg = gr.Slider(0.0, 1.0, value=0.35, step=0.01, label="Confidence (B)")
486
+ seg_alpha = gr.Slider(0.0, 1.0, value=SEG_DEFAULT_ALPHA, step=0.05, label="Alpha maschere (B)")
487
+ device_b = gr.Dropdown(["auto", "cuda:0", "cpu"], value="auto", label="Device")
488
+ btn_seg = gr.Button("Esegui Modello B (Seg)")
489
 
490
  with gr.Row():
491
+ btn_clear = gr.Button("Pulisci overlay", variant="secondary")
 
 
 
 
492
 
493
+ with gr.Accordion("Manutenzione spazio", open=False):
494
+ btn_check = gr.Button("Controlla storage")
495
+ btn_clean = gr.Button("Pulisci cache")
496
+ maint_out = gr.Textbox(label="Log manutenzione", interactive=False)
497
 
498
+ with gr.Column(scale=2):
499
+ img_out = gr.Image(label="Risultato combinato", type="numpy")
500
+ with gr.Row():
501
+ counts_out_det = gr.Textbox(label="Conteggi (A)", interactive=False)
502
+ counts_out_seg = gr.Textbox(label="Conteggi (B)", interactive=False)
503
+
504
+ # Wiring
505
+ img_in.change(
506
+ on_image_upload,
507
+ inputs=[img_in, state],
508
+ outputs=[img_out, state, counts_out_det, counts_out_seg],
509
+ )
510
 
511
+ btn_det.click(
512
+ run_det,
513
+ inputs=[
514
+ img_in, state,
515
+ weights_det, conf_det, slice_h, slice_w, overlap_h, overlap_w, device_a,
516
+ target, use_target, auto_opt_slice
517
+ ],
518
+ outputs=[img_out, state, counts_out_det, counts_out_seg],
519
+ )
520
 
521
+ btn_seg.click(
522
+ run_seg,
523
  inputs=[
524
+ img_in, state,
525
+ weights_seg, conf_seg, device_b,
526
+ target, use_target, seg_alpha
 
 
 
527
  ],
528
+ outputs=[img_out, state, counts_out_det, counts_out_seg],
529
  )
530
 
531
+ btn_clear.click(
532
+ clear_overlays,
533
+ inputs=[img_in, state],
534
+ outputs=[img_out, state, counts_out_det, counts_out_seg],
535
+ )
536
 
537
+ btn_check.click(
538
+ check_storage,
539
+ inputs=[],
540
+ outputs=[maint_out],
541
+ )
542
+
543
+ btn_clean.click(
544
+ clean_caches,
545
+ inputs=[],
546
+ outputs=[maint_out],
547
+ )
548
+
549
+ return demo
550
 
551
  if __name__ == "__main__":
552
  demo = build_app()
 
553
  demo.launch()