Update app.py
Browse files
app.py
CHANGED
|
@@ -45,7 +45,6 @@ def load_sam_mask_generator(points_per_side, pred_iou_thresh, stability_score_th
|
|
| 45 |
Carica il modello SAM e crea un SamAutomaticMaskGenerator
|
| 46 |
con parametri personalizzabili (passati da Streamlit).
|
| 47 |
"""
|
| 48 |
-
# Importa libreria dal repo segment_anything
|
| 49 |
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
|
| 50 |
|
| 51 |
# Scarichiamo il checkpoint
|
|
@@ -92,7 +91,7 @@ def analyze_mask_geometry(mask_bin):
|
|
| 92 |
"""Calcola bounding box e circolarità di una mask (0/255)."""
|
| 93 |
contours, _ = cv2.findContours(mask_bin, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 94 |
if not contours:
|
| 95 |
-
return {"bbox": (0,0,0,0), "circolarita": 0.0}
|
| 96 |
|
| 97 |
cnt = max(contours, key=cv2.contourArea)
|
| 98 |
x, y, w, h = cv2.boundingRect(cnt)
|
|
@@ -125,7 +124,7 @@ def classify_mask(image_np, mask_bin, candidate_labels, clip_processor, clip_mod
|
|
| 125 |
return_tensors="pt",
|
| 126 |
padding=True
|
| 127 |
)
|
| 128 |
-
#
|
| 129 |
inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in inputs.items()}
|
| 130 |
|
| 131 |
# Inference CLIP
|
|
@@ -156,7 +155,7 @@ def overlay_masks_auto(image, masks_with_labels):
|
|
| 156 |
import matplotlib.pyplot as plt
|
| 157 |
|
| 158 |
image_np = np.array(image)
|
| 159 |
-
fig, ax = plt.subplots(figsize=(6,6))
|
| 160 |
ax.imshow(image_np)
|
| 161 |
|
| 162 |
color_map = {}
|
|
@@ -167,14 +166,14 @@ def overlay_masks_auto(image, masks_with_labels):
|
|
| 167 |
|
| 168 |
# Assegna un colore a ogni label
|
| 169 |
if label not in color_map:
|
| 170 |
-
color_map[label] = np.random.randint(0,255,(3,), dtype=np.uint8)/255.0
|
| 171 |
|
| 172 |
color = color_map[label]
|
| 173 |
contours, _ = cv2.findContours(mask_bin, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 174 |
for cnt in contours:
|
| 175 |
cnt = cnt.squeeze()
|
| 176 |
if cnt.ndim == 2 and len(cnt) > 1:
|
| 177 |
-
ax.plot(cnt[:,0], cnt[:,1], color=color, linewidth=2, label=label)
|
| 178 |
|
| 179 |
ax.axis("off")
|
| 180 |
buf = BytesIO()
|
|
@@ -189,7 +188,7 @@ def overlay_masks_auto(image, masks_with_labels):
|
|
| 189 |
##############################################################
|
| 190 |
def union_lamiera_masks(df, h, w, lamiera_label="lamiera"):
|
| 191 |
"""
|
| 192 |
-
|
| 193 |
maschera, e sostituisce le righe multiple con una sola riga unificata.
|
| 194 |
"""
|
| 195 |
df_lamiera = df[df["Label"] == lamiera_label]
|
|
@@ -197,9 +196,9 @@ def union_lamiera_masks(df, h, w, lamiera_label="lamiera"):
|
|
| 197 |
return df # niente da unire
|
| 198 |
|
| 199 |
lamiera_mask_total = np.zeros((h, w), dtype=np.uint8)
|
| 200 |
-
for
|
| 201 |
segm = row["segmentation"] # mask 0..1
|
| 202 |
-
segm_bin = (segm*255).astype(np.uint8)
|
| 203 |
lamiera_mask_total = cv2.bitwise_or(lamiera_mask_total, segm_bin)
|
| 204 |
|
| 205 |
# Creiamo una riga unificata
|
|
@@ -212,9 +211,9 @@ def union_lamiera_masks(df, h, w, lamiera_label="lamiera"):
|
|
| 212 |
"Area(px)": int(area_sum),
|
| 213 |
"BoundingBox": str(geom_info["bbox"]),
|
| 214 |
"Circolarita": round(geom_info["circolarita"], 3),
|
| 215 |
-
"segmentation": (lamiera_mask_total / 255.0).astype(np.float32)
|
| 216 |
}
|
| 217 |
-
|
| 218 |
df_no_lam = df[df["Label"] != lamiera_label]
|
| 219 |
df_final = pd.concat([df_no_lam, pd.DataFrame([new_row])], ignore_index=True)
|
| 220 |
return df_final
|
|
@@ -230,31 +229,56 @@ def main():
|
|
| 230 |
con pass di geometria. Se la GPU è disponibile, verrà usata per velocizzare i calcoli.
|
| 231 |
""")
|
| 232 |
|
| 233 |
-
# Parametri SAM personalizzabili via Streamlit sidebar
|
| 234 |
st.sidebar.header("Parametri SAM")
|
| 235 |
-
points_per_side = st.sidebar.slider(
|
| 236 |
-
|
| 237 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
|
| 239 |
# Filtraggio maschere in base all'area minima
|
| 240 |
st.sidebar.header("Filtraggio Maschere")
|
| 241 |
-
min_area = st.sidebar.number_input(
|
|
|
|
|
|
|
|
|
|
| 242 |
|
| 243 |
-
#
|
| 244 |
mask_generator = load_sam_mask_generator(points_per_side, pred_iou_thresh, stability_score_thresh)
|
| 245 |
clip_model, clip_processor = load_clip_model()
|
| 246 |
|
| 247 |
# Upload immagini
|
| 248 |
-
uploaded_files = st.file_uploader(
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
|
| 254 |
if "segmentations_auto" not in st.session_state:
|
| 255 |
st.session_state["segmentations_auto"] = {}
|
| 256 |
|
| 257 |
-
#
|
|
|
|
|
|
|
| 258 |
if uploaded_files:
|
| 259 |
for file in uploaded_files:
|
| 260 |
st.image(file, caption=file.name, use_column_width=True)
|
|
@@ -271,18 +295,24 @@ def main():
|
|
| 271 |
# Filtra maschere troppo piccole
|
| 272 |
filtered_masks_info = [m for m in masks_info if m["area"] >= min_area]
|
| 273 |
|
|
|
|
| 274 |
st.session_state["segmentations_auto"][file.name] = {
|
| 275 |
"image": image_pil,
|
| 276 |
"masks": filtered_masks_info
|
| 277 |
}
|
| 278 |
-
percent = int((i+1)/len(uploaded_files)*100)
|
| 279 |
progress_seg.progress(percent)
|
| 280 |
st.success("Maschere generate con successo!")
|
| 281 |
|
| 282 |
-
#
|
|
|
|
|
|
|
| 283 |
if st.session_state["segmentations_auto"]:
|
| 284 |
st.header("Classificazione con CLIP")
|
| 285 |
-
st.markdown("
|
|
|
|
|
|
|
|
|
|
| 286 |
default_prompts = "lamiera, foro circolare, scanalatura rettangolare"
|
| 287 |
label_prompts = st.text_input("Prompt Zero-Shot", value=default_prompts)
|
| 288 |
candidate_labels = [lp.strip() for lp in label_prompts.split(",") if lp.strip()]
|
|
@@ -300,10 +330,11 @@ def main():
|
|
| 300 |
|
| 301 |
masks_with_labels = []
|
| 302 |
for idx_m, m_dict in enumerate(masks_info):
|
| 303 |
-
segm = m_dict["segmentation"] # 2D mask 0..1
|
| 304 |
area_px = m_dict["area"]
|
| 305 |
mask_bin = (segm * 255).astype(np.uint8)
|
| 306 |
|
|
|
|
| 307 |
label_pred, conf, geom_info = classify_mask(
|
| 308 |
image_np,
|
| 309 |
mask_bin,
|
|
@@ -312,6 +343,8 @@ def main():
|
|
| 312 |
clip_model,
|
| 313 |
do_geometry
|
| 314 |
)
|
|
|
|
|
|
|
| 315 |
row = {
|
| 316 |
"Indice": idx_m,
|
| 317 |
"Label": label_pred,
|
|
@@ -326,33 +359,35 @@ def main():
|
|
| 326 |
results_rows.append(row)
|
| 327 |
masks_with_labels.append({"mask": segm, "label": label_pred})
|
| 328 |
|
| 329 |
-
# Generiamo l'overlay
|
| 330 |
overlay_img = overlay_masks_auto(image_pil, masks_with_labels)
|
| 331 |
st.subheader(f"File: {fn}")
|
| 332 |
st.image(overlay_img, caption="Maschere colorate in base alle etichette")
|
| 333 |
|
| 334 |
-
progress_class.progress(int((idx_i+1)/len(seg_items)*100))
|
| 335 |
|
| 336 |
# Convertiamo in DataFrame
|
| 337 |
df = pd.DataFrame(results_rows)
|
| 338 |
-
|
| 339 |
-
#
|
| 340 |
if do_union_lamiera and not df.empty:
|
| 341 |
df = union_lamiera_masks(df, H, W, lamiera_label="lamiera")
|
| 342 |
|
|
|
|
|
|
|
|
|
|
| 343 |
st.write("**Tabella Risultati**")
|
| 344 |
-
st.dataframe(df)
|
| 345 |
|
| 346 |
-
# Conteggio
|
| 347 |
-
st.write("**Conteggio**:")
|
| 348 |
if not df.empty:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 349 |
count_labels = df["Label"].value_counts()
|
| 350 |
st.write(count_labels)
|
| 351 |
-
else:
|
| 352 |
-
st.write("Nessuna maschera valida trovata.")
|
| 353 |
|
| 354 |
-
|
| 355 |
-
if not df.empty:
|
| 356 |
csv_buf = df.to_csv(index=False).encode("utf-8")
|
| 357 |
st.download_button(
|
| 358 |
label="Scarica CSV",
|
|
@@ -360,6 +395,8 @@ def main():
|
|
| 360 |
file_name="risultati_clip_automatic.csv",
|
| 361 |
mime="text/csv"
|
| 362 |
)
|
|
|
|
|
|
|
| 363 |
|
| 364 |
|
| 365 |
if __name__ == "__main__":
|
|
|
|
| 45 |
Carica il modello SAM e crea un SamAutomaticMaskGenerator
|
| 46 |
con parametri personalizzabili (passati da Streamlit).
|
| 47 |
"""
|
|
|
|
| 48 |
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
|
| 49 |
|
| 50 |
# Scarichiamo il checkpoint
|
|
|
|
| 91 |
"""Calcola bounding box e circolarità di una mask (0/255)."""
|
| 92 |
contours, _ = cv2.findContours(mask_bin, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 93 |
if not contours:
|
| 94 |
+
return {"bbox": (0, 0, 0, 0), "circolarita": 0.0}
|
| 95 |
|
| 96 |
cnt = max(contours, key=cv2.contourArea)
|
| 97 |
x, y, w, h = cv2.boundingRect(cnt)
|
|
|
|
| 124 |
return_tensors="pt",
|
| 125 |
padding=True
|
| 126 |
)
|
| 127 |
+
# Sposta i tensori su GPU
|
| 128 |
inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in inputs.items()}
|
| 129 |
|
| 130 |
# Inference CLIP
|
|
|
|
| 155 |
import matplotlib.pyplot as plt
|
| 156 |
|
| 157 |
image_np = np.array(image)
|
| 158 |
+
fig, ax = plt.subplots(figsize=(6, 6))
|
| 159 |
ax.imshow(image_np)
|
| 160 |
|
| 161 |
color_map = {}
|
|
|
|
| 166 |
|
| 167 |
# Assegna un colore a ogni label
|
| 168 |
if label not in color_map:
|
| 169 |
+
color_map[label] = np.random.randint(0, 255, (3,), dtype=np.uint8) / 255.0
|
| 170 |
|
| 171 |
color = color_map[label]
|
| 172 |
contours, _ = cv2.findContours(mask_bin, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 173 |
for cnt in contours:
|
| 174 |
cnt = cnt.squeeze()
|
| 175 |
if cnt.ndim == 2 and len(cnt) > 1:
|
| 176 |
+
ax.plot(cnt[:, 0], cnt[:, 1], color=color, linewidth=2, label=label)
|
| 177 |
|
| 178 |
ax.axis("off")
|
| 179 |
buf = BytesIO()
|
|
|
|
| 188 |
##############################################################
|
| 189 |
def union_lamiera_masks(df, h, w, lamiera_label="lamiera"):
|
| 190 |
"""
|
| 191 |
+
Unisce tutte le maschere con etichetta 'lamiera' in una sola
|
| 192 |
maschera, e sostituisce le righe multiple con una sola riga unificata.
|
| 193 |
"""
|
| 194 |
df_lamiera = df[df["Label"] == lamiera_label]
|
|
|
|
| 196 |
return df # niente da unire
|
| 197 |
|
| 198 |
lamiera_mask_total = np.zeros((h, w), dtype=np.uint8)
|
| 199 |
+
for _, row in df_lamiera.iterrows():
|
| 200 |
segm = row["segmentation"] # mask 0..1
|
| 201 |
+
segm_bin = (segm * 255).astype(np.uint8)
|
| 202 |
lamiera_mask_total = cv2.bitwise_or(lamiera_mask_total, segm_bin)
|
| 203 |
|
| 204 |
# Creiamo una riga unificata
|
|
|
|
| 211 |
"Area(px)": int(area_sum),
|
| 212 |
"BoundingBox": str(geom_info["bbox"]),
|
| 213 |
"Circolarita": round(geom_info["circolarita"], 3),
|
| 214 |
+
"segmentation": (lamiera_mask_total / 255.0).astype(np.float32)
|
| 215 |
}
|
| 216 |
+
|
| 217 |
df_no_lam = df[df["Label"] != lamiera_label]
|
| 218 |
df_final = pd.concat([df_no_lam, pd.DataFrame([new_row])], ignore_index=True)
|
| 219 |
return df_final
|
|
|
|
| 229 |
con pass di geometria. Se la GPU è disponibile, verrà usata per velocizzare i calcoli.
|
| 230 |
""")
|
| 231 |
|
| 232 |
+
# Parametri SAM personalizzabili via Streamlit sidebar, con tooltip (help).
|
| 233 |
st.sidebar.header("Parametri SAM")
|
| 234 |
+
points_per_side = st.sidebar.slider(
|
| 235 |
+
"Points per side", 0, 128, 32,
|
| 236 |
+
help="Numero di punti su ogni lato della bounding box di input usati come prompt. "
|
| 237 |
+
"Più punti => segmentazione più dettagliata e più lenta. (Default=32)"
|
| 238 |
+
)
|
| 239 |
+
pred_iou_thresh = st.sidebar.slider(
|
| 240 |
+
"Pred IoU Threshold", 0.0, 1.0, 0.8,
|
| 241 |
+
help="Soglia minima di confidenza IoU per le maschere. 0.8 è un buon compromesso."
|
| 242 |
+
)
|
| 243 |
+
stability_score_thresh = st.sidebar.slider(
|
| 244 |
+
"Stability Score Threshold", 0.0, 1.0, 0.9,
|
| 245 |
+
help="Soglia di stabilità della maschera. Valori alti filtrano maschere poco stabili. (Default=0.9)"
|
| 246 |
+
)
|
| 247 |
|
| 248 |
# Filtraggio maschere in base all'area minima
|
| 249 |
st.sidebar.header("Filtraggio Maschere")
|
| 250 |
+
min_area = st.sidebar.number_input(
|
| 251 |
+
"Area minima (px)", min_value=0, value=100,
|
| 252 |
+
help="Filtra via maschere con area (in pixel) inferiore a questa soglia."
|
| 253 |
+
)
|
| 254 |
|
| 255 |
+
# Caricamento e inizializzazione modelli
|
| 256 |
mask_generator = load_sam_mask_generator(points_per_side, pred_iou_thresh, stability_score_thresh)
|
| 257 |
clip_model, clip_processor = load_clip_model()
|
| 258 |
|
| 259 |
# Upload immagini
|
| 260 |
+
uploaded_files = st.file_uploader(
|
| 261 |
+
"Carica immagini (JPG/PNG)",
|
| 262 |
+
type=["jpg", "jpeg", "png"],
|
| 263 |
+
accept_multiple_files=True
|
| 264 |
+
)
|
| 265 |
+
do_geometry = st.checkbox(
|
| 266 |
+
"Calcolo circolarità e bounding box",
|
| 267 |
+
value=True,
|
| 268 |
+
help="Se spuntato, calcola parametri come circolarità e bounding box per ciascuna maschera."
|
| 269 |
+
)
|
| 270 |
+
do_union_lamiera = st.checkbox(
|
| 271 |
+
"Unisci maschere lamiera in 1 sola maschera",
|
| 272 |
+
value=True,
|
| 273 |
+
help="Se spuntato, tutte le maschere classificate come 'lamiera' vengono unite in un'unica maschera cumulativa."
|
| 274 |
+
)
|
| 275 |
|
| 276 |
if "segmentations_auto" not in st.session_state:
|
| 277 |
st.session_state["segmentations_auto"] = {}
|
| 278 |
|
| 279 |
+
########################
|
| 280 |
+
# A) Genera Maschere #
|
| 281 |
+
########################
|
| 282 |
if uploaded_files:
|
| 283 |
for file in uploaded_files:
|
| 284 |
st.image(file, caption=file.name, use_column_width=True)
|
|
|
|
| 295 |
# Filtra maschere troppo piccole
|
| 296 |
filtered_masks_info = [m for m in masks_info if m["area"] >= min_area]
|
| 297 |
|
| 298 |
+
# Salviamo nel session_state
|
| 299 |
st.session_state["segmentations_auto"][file.name] = {
|
| 300 |
"image": image_pil,
|
| 301 |
"masks": filtered_masks_info
|
| 302 |
}
|
| 303 |
+
percent = int((i + 1) / len(uploaded_files) * 100)
|
| 304 |
progress_seg.progress(percent)
|
| 305 |
st.success("Maschere generate con successo!")
|
| 306 |
|
| 307 |
+
########################
|
| 308 |
+
# B) Classifica SAM #
|
| 309 |
+
########################
|
| 310 |
if st.session_state["segmentations_auto"]:
|
| 311 |
st.header("Classificazione con CLIP")
|
| 312 |
+
st.markdown("""
|
| 313 |
+
Inserisci le etichette di classificazione separate da virgola.
|
| 314 |
+
Esempio: "lamiera, foro circolare, scanalatura rettangolare, albero".
|
| 315 |
+
""")
|
| 316 |
default_prompts = "lamiera, foro circolare, scanalatura rettangolare"
|
| 317 |
label_prompts = st.text_input("Prompt Zero-Shot", value=default_prompts)
|
| 318 |
candidate_labels = [lp.strip() for lp in label_prompts.split(",") if lp.strip()]
|
|
|
|
| 330 |
|
| 331 |
masks_with_labels = []
|
| 332 |
for idx_m, m_dict in enumerate(masks_info):
|
| 333 |
+
segm = m_dict["segmentation"] # 2D mask (0..1)
|
| 334 |
area_px = m_dict["area"]
|
| 335 |
mask_bin = (segm * 255).astype(np.uint8)
|
| 336 |
|
| 337 |
+
# Classifica con CLIP
|
| 338 |
label_pred, conf, geom_info = classify_mask(
|
| 339 |
image_np,
|
| 340 |
mask_bin,
|
|
|
|
| 343 |
clip_model,
|
| 344 |
do_geometry
|
| 345 |
)
|
| 346 |
+
|
| 347 |
+
# Costruiamo la riga da aggiungere al DF
|
| 348 |
row = {
|
| 349 |
"Indice": idx_m,
|
| 350 |
"Label": label_pred,
|
|
|
|
| 359 |
results_rows.append(row)
|
| 360 |
masks_with_labels.append({"mask": segm, "label": label_pred})
|
| 361 |
|
| 362 |
+
# Generiamo l'overlay delle maschere
|
| 363 |
overlay_img = overlay_masks_auto(image_pil, masks_with_labels)
|
| 364 |
st.subheader(f"File: {fn}")
|
| 365 |
st.image(overlay_img, caption="Maschere colorate in base alle etichette")
|
| 366 |
|
| 367 |
+
progress_class.progress(int((idx_i + 1) / len(seg_items) * 100))
|
| 368 |
|
| 369 |
# Convertiamo in DataFrame
|
| 370 |
df = pd.DataFrame(results_rows)
|
| 371 |
+
|
| 372 |
+
# Unione lamiera (opzionale)
|
| 373 |
if do_union_lamiera and not df.empty:
|
| 374 |
df = union_lamiera_masks(df, H, W, lamiera_label="lamiera")
|
| 375 |
|
| 376 |
+
#################################
|
| 377 |
+
# Visualizzazione e Download CSV
|
| 378 |
+
#################################
|
| 379 |
st.write("**Tabella Risultati**")
|
|
|
|
| 380 |
|
|
|
|
|
|
|
| 381 |
if not df.empty:
|
| 382 |
+
# Evitiamo l'errore ArrowInvalid rimuovendo la colonna "segmentation" dalla sola vista
|
| 383 |
+
df_display = df.drop(columns=["segmentation"]).copy()
|
| 384 |
+
st.dataframe(df_display)
|
| 385 |
+
|
| 386 |
+
st.write("**Conteggio**:")
|
| 387 |
count_labels = df["Label"].value_counts()
|
| 388 |
st.write(count_labels)
|
|
|
|
|
|
|
| 389 |
|
| 390 |
+
# Download CSV (includiamo anche "segmentation")
|
|
|
|
| 391 |
csv_buf = df.to_csv(index=False).encode("utf-8")
|
| 392 |
st.download_button(
|
| 393 |
label="Scarica CSV",
|
|
|
|
| 395 |
file_name="risultati_clip_automatic.csv",
|
| 396 |
mime="text/csv"
|
| 397 |
)
|
| 398 |
+
else:
|
| 399 |
+
st.write("Nessuna maschera valida trovata.")
|
| 400 |
|
| 401 |
|
| 402 |
if __name__ == "__main__":
|