colomboMk committed on
Commit
a1d7bb7
·
verified ·
1 Parent(s): 715a3cf

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +368 -0
app.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import gradio as gr
5
+
6
+ from sahi import AutoDetectionModel
7
+ from sahi.predict import get_sliced_prediction
8
+
9
+ # Prova a importare ultralytics per il modello di segmentazione nativo (senza SAHI)
10
+ try:
11
+ from ultralytics import YOLO
12
+ _ULTRALYTICS_AVAILABLE = True
13
+ except Exception:
14
+ _ULTRALYTICS_AVAILABLE = False
15
+
16
+ # Soglia massima consentita per il lato della bbox (in pixel) per il modello con SAHI
17
+ MAX_SIDE_PX = 70
18
+
19
+
20
def _draw_boxes_rgb(image_rgb: np.ndarray, result, target_class: str, max_side_px=None):
    """
    Draw detection bounding boxes (no text labels) on an RGB frame.

    Boxes whose category name equals ``target_class`` are drawn in red, every
    other box in green. A box is discarded when it is degenerate or when its
    longest side exceeds the size limit. The limit is checked on the box as
    predicted, *before* it is clipped to the image, so an oversized detection
    that straddles the image border is not accepted merely because clipping
    shrank it.

    Args:
        image_rgb: 2-D grayscale, RGB or RGBA array. It is never modified.
        result: Object exposing ``object_prediction_list``; every entry must
            provide ``bbox`` (with ``to_xyxy()`` or the attributes
            ``minx``/``miny``/``maxx``/``maxy``) and ``category.name``.
        target_class: Category name to highlight and to count separately.
        max_side_px: Optional maximum allowed box side in pixels. When
            ``None`` the module-level ``MAX_SIDE_PX`` is used.

    Returns:
        Tuple ``(annotated_rgb_image, counts_text)``.
    """
    side_limit = MAX_SIDE_PX if max_side_px is None else max_side_px

    # Always work on a private 3-channel canvas.
    if image_rgb.ndim == 2:
        canvas = cv2.cvtColor(image_rgb, cv2.COLOR_GRAY2RGB)
    elif image_rgb.shape[2] == 4:
        canvas = cv2.cvtColor(image_rgb, cv2.COLOR_RGBA2RGB)
    else:
        canvas = image_rgb.copy()

    height, width = canvas.shape[:2]

    # The canvas is RGB, so colours are given in RGB order as well.
    target_color = (255, 0, 0)  # red
    other_color = (0, 200, 0)  # green

    target_count = 0
    total_count = 0

    for item in getattr(result, "object_prediction_list", None) or []:
        bbox = item.bbox
        try:
            x1, y1, x2, y2 = map(int, bbox.to_xyxy())
        except Exception:
            # Fall back to the individual corner attributes.
            x1, y1 = int(getattr(bbox, "minx", 0)), int(getattr(bbox, "miny", 0))
            x2, y2 = int(getattr(bbox, "maxx", 0)), int(getattr(bbox, "maxy", 0))

        # Put the corners in (min, max) order in case they arrive swapped.
        if x2 < x1:
            x1, x2 = x2, x1
        if y2 < y1:
            y1, y2 = y2, y1

        # Size filter on the predicted box, before any clipping.
        box_w = x2 - x1
        box_h = y2 - y1
        if box_w == 0 or box_h == 0:
            continue
        if max(box_w, box_h) > side_limit:
            continue

        # Clip to the image only for drawing.
        x1 = max(0, min(x1, width - 1))
        y1 = max(0, min(y1, height - 1))
        x2 = max(0, min(x2, width - 1))
        y2 = max(0, min(y2, height - 1))
        if x2 <= x1 or y2 <= y1:
            # Nothing of the box is left inside the image.
            continue

        is_target = getattr(item.category, "name", "unknown") == target_class
        cv2.rectangle(canvas, (x1, y1), (x2, y2), target_color if is_target else other_color, 2)

        total_count += 1
        if is_target:
            target_count += 1

    counts_text = f"target='{target_class}': {target_count} | totale: {total_count}"
    return canvas, counts_text
96
+
97
+
98
def _draw_segmentation_masks_rgb(image_rgb: np.ndarray, ulty_result, target_class: str, alpha: float = 0.45):
    """
    Draw instance segmentation masks (no text labels) on an RGB frame.

    Instances whose class name equals ``target_class`` are tinted red, all
    others green. Only the pixels inside a mask are tinted, so the rest of
    the picture keeps its original brightness no matter how many instances
    are drawn. When an instance has no usable mask its bounding box is
    outlined instead.

    Args:
        image_rgb: 2-D grayscale, RGB or RGBA array. It is never modified.
        ulty_result: Result object exposing ``names`` (index -> class name),
            ``boxes`` (with ``cls`` and ``xyxy``) and optionally ``masks``
            (with ``data``, presumably one mask per instance).
        target_class: Class name to highlight and to count separately.
        alpha: Opacity of the colour tint inside each mask, in [0, 1].

    Returns:
        Tuple ``(annotated_rgb_image, counts_text)``.
    """
    # Always work on a private 3-channel canvas.
    if image_rgb.ndim == 2:
        canvas = cv2.cvtColor(image_rgb, cv2.COLOR_GRAY2RGB)
    elif image_rgb.shape[2] == 4:
        canvas = cv2.cvtColor(image_rgb, cv2.COLOR_RGBA2RGB)
    else:
        canvas = image_rgb.copy()

    names = getattr(ulty_result, "names", None)
    boxes = getattr(ulty_result, "boxes", None)
    masks = getattr(ulty_result, "masks", None)

    if boxes is None or len(boxes) == 0:
        # No detections at all.
        return canvas, f"target='{target_class}': 0 | totale: 0"

    mask_data = getattr(masks, "data", None)

    # The canvas is RGB, so colours are given in RGB order as well.
    target_color = (255, 0, 0)  # red
    other_color = (0, 200, 0)  # green

    target_count = 0
    total_count = 0

    for index in range(len(boxes)):
        try:
            cls_idx = int(boxes.cls[index].item())
        except Exception:
            cls_idx = -1

        is_target = _resolve_class_name(names, cls_idx) == target_class
        color = target_color if is_target else other_color

        mask_drawn = False
        if mask_data is not None and index < len(mask_data):
            try:
                _overlay_instance_mask(canvas, mask_data[index], color, alpha)
                mask_drawn = True
            except Exception:
                # Unusable mask: fall through to the bounding-box outline.
                mask_drawn = False
        if not mask_drawn:
            _outline_instance_box(canvas, boxes, index, color)

        # Every instance is counted, whether or not it could be drawn.
        total_count += 1
        if is_target:
            target_count += 1

    counts_text = f"target='{target_class}': {target_count} | totale: {total_count}"
    return canvas, counts_text


def _resolve_class_name(names, cls_idx):
    """Map a class index to its name; fall back to the index as a string."""
    if isinstance(names, dict):
        return names.get(cls_idx, str(cls_idx))
    if isinstance(names, (list, tuple)) and 0 <= cls_idx < len(names):
        return names[cls_idx]
    return str(cls_idx)


def _overlay_instance_mask(canvas, mask, color_rgb, alpha):
    """Tint the pixels of one instance mask on ``canvas`` in place and outline it."""
    if hasattr(mask, "detach"):
        # Tensor-like mask: bring it to host memory as a numpy array.
        mask = mask.detach().cpu().numpy()
    binary = (np.asarray(mask) > 0.5).astype(np.uint8)

    # The mask must have exactly the spatial size of the canvas.
    if binary.shape[:2] != canvas.shape[:2]:
        binary = cv2.resize(
            binary,
            (canvas.shape[1], canvas.shape[0]),
            interpolation=cv2.INTER_NEAREST,
        )

    # [-2] selects the contour list with both the 2- and 3-value return forms.
    contours = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

    # Blend only where the mask is set; pixels outside stay untouched.
    inside = binary.astype(bool)
    tinted = (1.0 - alpha) * canvas[inside].astype(np.float32)
    tinted += alpha * np.asarray(color_rgb, dtype=np.float32)
    # NOTE(review): the 0..255 clip assumes an 8-bit image - confirm if other dtypes are fed in.
    canvas[inside] = np.clip(np.rint(tinted), 0, 255).astype(canvas.dtype)

    cv2.drawContours(canvas, contours, -1, color_rgb, 2)


def _outline_instance_box(canvas, boxes, index, color_rgb):
    """Outline the bounding box of instance ``index``; return True on success."""
    try:
        corners = boxes.xyxy[index]
        if hasattr(corners, "detach"):
            corners = corners.detach().cpu().numpy()
        x1, y1, x2, y2 = (int(value) for value in np.asarray(corners).astype(int))
        cv2.rectangle(canvas, (x1, y1), (x2, y2), color_rgb, 2)
    except Exception:
        # Best effort: an instance that cannot be drawn must not abort the rendering.
        return False
    return True
192
+
193
+
194
def infer_two_models(
    image: np.ndarray,
    weights_det_path: str,
    conf_det: float,
    slice_h: int,
    slice_w: int,
    overlap_h: float,
    overlap_w: float,
    device: str,
    target_class: str,
    weights_seg_path: str,
    conf_seg: float,
):
    """
    Run two models on a single image and render both results.

    - Model A (detection through SAHI sliced prediction): only bounding boxes
      are drawn, and oversized boxes are filtered out by ``_draw_boxes_rgb``.
    - Model B (native YOLO segmentation, no SAHI): only masks are drawn.

    Args:
        image: Input image as an RGB numpy array.
        weights_det_path: Path of the weights file for model A.
        conf_det: Confidence threshold for model A.
        slice_h: SAHI slice height in pixels.
        slice_w: SAHI slice width in pixels.
        overlap_h: SAHI vertical overlap ratio between slices.
        overlap_w: SAHI horizontal overlap ratio between slices.
        device: ``"auto"`` or an explicit device string such as ``"cpu"``.
        target_class: Class name highlighted and counted separately.
        weights_seg_path: Path of the weights file for model B.
        conf_seg: Confidence threshold for model B.

    Returns:
        Tuple ``(img_det, counts_det, img_seg, counts_seg)``.

    Raises:
        gr.Error: On missing image or weights, missing ultralytics, or any
            failure while loading or running either model.
    """
    if image is None:
        raise gr.Error("Devi caricare un'immagine.")

    if not weights_det_path or not os.path.exists(weights_det_path):
        raise gr.Error(f"File pesi (Detection/SAHI) non trovato: {weights_det_path}")

    if not weights_seg_path or not os.path.exists(weights_seg_path):
        raise gr.Error(f"File pesi (Segmentation) non trovato: {weights_seg_path}")

    if not _ULTRALYTICS_AVAILABLE:
        raise gr.Error("Ultralytics non è installato per il modello di segmentazione. Installa con: pip install ultralytics")

    image_rgb = image.copy()
    chosen_device = _resolve_device(device)

    # =========================
    # Model A: detection with SAHI (boxes only)
    # =========================
    try:
        detection_model = _load_sahi_detector(weights_det_path, conf_det, chosen_device)
    except Exception as err:
        if chosen_device == "cpu":
            # Already on CPU: a retry would only repeat the same failure.
            raise gr.Error(f"Errore durante il caricamento del modello di detection: {err}") from err
        # Retry on CPU and keep using CPU for model B too, so both models
        # run on a device that is known to work.
        chosen_device = "cpu"
        try:
            detection_model = _load_sahi_detector(weights_det_path, conf_det, chosen_device)
        except Exception as cpu_err:
            raise gr.Error(f"Errore durante il caricamento del modello di detection: {cpu_err}") from cpu_err

    try:
        sahi_result = get_sliced_prediction(
            image_rgb,
            detection_model,
            slice_height=int(slice_h),
            slice_width=int(slice_w),
            overlap_height_ratio=float(overlap_h),
            overlap_width_ratio=float(overlap_w),
            postprocess_class_agnostic=False,
            verbose=0,
        )
    except Exception as err:
        raise gr.Error(f"Errore durante l'inferenza del modello di detection: {err}") from err

    det_vis_rgb, det_counts_text = _draw_boxes_rgb(image_rgb, sahi_result, target_class)

    # =========================
    # Model B: native YOLO segmentation (no SAHI)
    # =========================
    # Ultralytics interprets numpy inputs as BGR (OpenCV convention), so the
    # RGB frame is flipped before prediction. Rendering still uses the RGB frame.
    seg_input = image_rgb
    if image_rgb.ndim == 3 and image_rgb.shape[2] == 3:
        seg_input = np.ascontiguousarray(image_rgb[:, :, ::-1])

    try:
        seg_model = YOLO(weights_seg_path)
        seg_results = seg_model.predict(
            source=seg_input,
            conf=conf_seg,
            device=chosen_device,
            verbose=False,
        )
        # Keep the first (and only) result.
        r0 = seg_results[0] if isinstance(seg_results, (list, tuple)) else seg_results
    except Exception as e:
        raise gr.Error(f"Errore durante l'inferenza del modello di segmentazione: {e}") from e

    seg_vis_rgb, seg_counts_text = _draw_segmentation_masks_rgb(image_rgb, r0, target_class)

    return det_vis_rgb, det_counts_text, seg_vis_rgb, seg_counts_text


def _resolve_device(device):
    """Turn ``"auto"`` into ``"cuda:0"`` or ``"cpu"``; return other values unchanged."""
    if device != "auto":
        return device
    try:
        import torch

        return "cuda:0" if torch.cuda.is_available() else "cpu"
    except Exception:
        # torch missing or unusable: CPU is the only safe choice.
        return "cpu"


def _load_sahi_detector(weights_path, confidence, device):
    """Load the SAHI detection wrapper for model A on the given device."""
    return AutoDetectionModel.from_pretrained(
        model_type="yolov11",
        model_path=weights_path,
        confidence_threshold=confidence,
        device=device,
    )
290
+
291
+
292
def build_app():
    """
    Assemble the Gradio Blocks interface and return it (not launched).

    The left column holds the input image, the weight paths, the target
    class and the parameters of both models; the right column shows the
    annotated image and the counts of each model. The run button is wired
    to ``infer_two_models``.
    """
    intro_text = "\n".join(
        [
            "- Carica un'immagine e lancia l'inferenza con due modelli YOLO.",
            "- Modello A dedicato al rilevamento e conteggio di acini.",
            "- Modello B dedicato alla segmentazione di grappoli.",
        ]
    )

    with gr.Blocks(title="Berries counting and bunches segmentation - Owl-Nest") as app:
        gr.Markdown(intro_text)

        with gr.Row():
            # Left column: inputs and parameters.
            with gr.Column():
                input_image = gr.Image(label="Immagine", type="numpy")

                gr.Markdown("### Pesi modelli")
                det_weights_box = gr.Textbox(
                    label="Percorso pesi Modello A",
                    value="weights/berry.pt",
                    placeholder="es. weights/best.pt",
                )
                seg_weights_box = gr.Textbox(
                    label="Percorso pesi Modello B",
                    value="weights/bunch.pt",
                    placeholder="es. weights/seg.pt",
                )

                target_box = gr.Textbox(label="Classe target", value="berry")

                gr.Markdown("### Parametri modello A")
                with gr.Row():
                    det_conf_slider = gr.Slider(
                        minimum=0.0, maximum=1.0, value=0.35, step=0.01, label="Confidence (A)"
                    )
                    device_choice = gr.Dropdown(
                        choices=["auto", "cuda:0", "cpu"],
                        value="auto",
                        label="Device",
                    )

                with gr.Row():
                    slice_h_slider = gr.Slider(
                        minimum=64, maximum=2048, value=640, step=32, label="Slice H (A)"
                    )
                    slice_w_slider = gr.Slider(
                        minimum=64, maximum=2048, value=640, step=32, label="Slice W (A)"
                    )

                with gr.Row():
                    overlap_h_slider = gr.Slider(
                        minimum=0.0, maximum=0.9, value=0.10, step=0.01, label="Overlap H ratio (A)"
                    )
                    overlap_w_slider = gr.Slider(
                        minimum=0.0, maximum=0.9, value=0.10, step=0.01, label="Overlap W ratio (A)"
                    )

                gr.Markdown("### Parametri modello B")
                seg_conf_slider = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.35, step=0.01, label="Confidence (B)"
                )

                run_button = gr.Button("Esegui inferenza", variant="primary")

            # Right column: results of both models.
            with gr.Column():
                gr.Markdown("### Risultato Modello A")
                det_image_out = gr.Image(label="Detections (solo bbox)", type="numpy")
                det_counts_out = gr.Textbox(label="Conteggi (A)", interactive=False)

                gr.Markdown("### Risultato Modello B")
                seg_image_out = gr.Image(label="Segmentazione (maschere)", type="numpy")
                seg_counts_out = gr.Textbox(label="Conteggi (B)", interactive=False)

        # Order must match the positional parameters of infer_two_models.
        callback_inputs = [
            input_image,
            det_weights_box,
            det_conf_slider,
            slice_h_slider,
            slice_w_slider,
            overlap_h_slider,
            overlap_w_slider,
            device_choice,
            target_box,
            seg_weights_box,
            seg_conf_slider,
        ]
        callback_outputs = [det_image_out, det_counts_out, seg_image_out, seg_counts_out]

        run_button.click(
            fn=infer_two_models,
            inputs=callback_inputs,
            outputs=callback_outputs,
        )

    return app
363
+
364
+
365
if __name__ == "__main__":
    # Script entry point: build the interface, then serve it.
    # On Spaces there is no need to pass server_name or share.
    demo = build_app()
    demo.launch()