# pages/1_OR1_visual_classification.py
"""OR1 – visual recognition & classification Streamlit page.

Segments uploaded images with SAM, classifies each segment with CLIP,
flags circular segments as holes, merges redundant "lamiera" (sheet)
masks, and offers the results as a CSV download.
"""
import cv2
import numpy as np
import pandas as pd
import streamlit as st
import torch
from PIL import Image

from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
from transformers import CLIPModel, CLIPProcessor

from utils import download_checkpoint, get_device

# ────────────────────────────────────────────────────────────────
# 1) Fix torchvision NMS (GPU errors)
# ────────────────────────────────────────────────────────────────
device = get_device()
if device.type == "cuda":
    try:
        from torchvision.ops import boxes as _bx

        _orig_nms = _bx._batched_nms_vanilla

        def _batched_nms_fix(boxes, scores, idxs, thr):
            # torchvision may hand us `idxs` on a different device than
            # `boxes`; align them before delegating to the original NMS.
            return _orig_nms(boxes, scores, idxs.to(boxes.device), thr)

        _bx._batched_nms_vanilla = _batched_nms_fix
    except Exception as e:
        st.warning(f"Patch NMS non riuscita: {e}")


# ────────────────────────────────────────────────────────────────
# 2) Helper functions
# ────────────────────────────────────────────────────────────────
def is_hole(mask_bin, min_circ=0.7, min_area=200):
    """Return True if the binary mask looks like a circular hole.

    Args:
        mask_bin: uint8 binary mask (0/255).
        min_circ: minimum circularity (4*pi*area / perimeter^2), 1.0 = circle.
        min_area: minimum contour area in pixels.
    """
    cnts, _ = cv2.findContours(mask_bin, cv2.RETR_EXTERNAL,
                               cv2.CHAIN_APPROX_SIMPLE)
    if not cnts:
        return False
    cnt = max(cnts, key=cv2.contourArea)
    area = cv2.contourArea(cnt)
    # `or 1` guards against a zero perimeter (degenerate contour).
    peri = cv2.arcLength(cnt, True) or 1
    circ = 4 * np.pi * area / peri**2
    return area >= min_area and circ >= min_circ


@st.cache_resource(show_spinner=False)
def load_sam(pts, iou, stab, cpu=False):
    """Download (if needed) and build a cached SAM automatic mask generator.

    Args:
        pts: grid points per side (clamped to >= 1 — 0 is invalid for SAM).
        iou: predicted IoU threshold for mask filtering.
        stab: stability-score threshold for mask filtering.
        cpu: force the model onto CPU (workaround for GPU errors).
    """
    url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
    fname = "sam_vit_h_4b8939.pth"
    download_checkpoint(url, fname)
    model_dev = torch.device("cpu") if cpu else device
    sam = sam_model_registry["vit_h"](checkpoint=fname).to(model_dev).eval()
    return SamAutomaticMaskGenerator(
        sam,
        points_per_side=max(pts, 1),  # the UI slider allows 0, which SAM rejects
        pred_iou_thresh=iou,
        stability_score_thresh=stab,
    )


@st.cache_resource(show_spinner=False)
def load_clip():
    """Load the cached CLIP model (on `device`) and its processor."""
    name = "laion/CLIP-ViT-L-14-laion2B-s32B-b82K"
    return (CLIPModel.from_pretrained(name).to(device),
            CLIPProcessor.from_pretrained(name))


def _classify_crop(clip_model, clip_proc, crop_pil, labels):
    """Classify one image crop against `labels`; return (label, confidence)."""
    inputs = clip_proc(text=labels, images=crop_pil,
                       return_tensors="pt", padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # Inference only: no_grad avoids retaining an autograd graph per mask.
    with torch.no_grad():
        logits = clip_model(**inputs).logits_per_image
    probs = logits.softmax(1)
    return labels[int(probs.argmax())], float(probs.max())


# ────────────────────────────────────────────────────────────────
# 3) UI
# ────────────────────────────────────────────────────────────────
st.set_page_config(page_title="OR1 – Riconoscimento", layout="wide")
st.title("🧩 OR1 – Riconoscimento visivo & classificazione")
st.sidebar.success(f"Device: **{device.type}**")

with st.sidebar.expander("Impostazioni avanzate", expanded=False):
    pts = st.slider("Points per side", 0, 128, 32)
    iou = st.slider("Pred IoU threshold", 0.0, 1.0, 0.8)
    stab = st.slider("Stability score threshold", 0.0, 1.0, 0.9)
    min_area_px = st.number_input("Area minima segmento (px)", 0, 50_000, 500)
    circ_thr = st.slider("Soglia circolarità foro", 0.0, 1.0, 0.7)
    force_cpu = st.checkbox("Forza SAM su CPU (workaround errori GPU)")

uploaded = st.file_uploader("Carica immagini JPG/PNG",
                            type=["jpg", "jpeg", "png"],
                            accept_multiple_files=True)
labels = [x.strip() for x in st.text_input(
    "Etichette (separate da virgola)",
    "lamiera, scanalatura rettangolare").split(",") if x.strip()]

# ────────────────────────────────────────────────────────────────
# 4) Analysis
# ────────────────────────────────────────────────────────────────
if uploaded and st.button("🔎 Analizza immagini"):
    if not labels:
        # CLIP cannot run with an empty text list — fail early and clearly.
        st.error("Inserisci almeno un'etichetta.")
        st.stop()

    clip_model, clip_proc = load_clip()
    sam_gen = load_sam(pts, iou, stab, cpu=force_cpu)
    global_bar = st.progress(0, text="Analisi immagini")
    all_rows = []

    for idx_file, file in enumerate(uploaded, start=1):
        img_pil = Image.open(file).convert("RGB")
        img_np = np.array(img_pil)
        H, W = img_np.shape[:2]
        st.subheader(f"📎 {file.name}")
        st.image(img_pil, caption="Immagine originale", use_container_width=True)

        # ---- segmentation with an inner progress bar ----
        with st.spinner("Segmentazione + classificazione…"):
            masks = sam_gen.generate(img_np)
            masks = [m for m in masks if m["area"] >= min_area_px]
            inner_bar = st.progress(0)
            masks_info = []
            for j, m in enumerate(masks, start=1):
                segm = m["segmentation"]
                mask_bin = (segm * 255).astype(np.uint8)

                # BUG FIX: the original code passed the *whole* image to CLIP
                # for every mask, so all segments got identical scores.
                # Classify only the masked bounding-box crop instead.
                x0, y0, w0, h0 = (int(v) for v in m["bbox"])
                if w0 > 0 and h0 > 0:
                    crop_np = img_np[y0:y0 + h0, x0:x0 + w0].copy()
                    crop_mask = segm[y0:y0 + h0, x0:x0 + w0].astype(bool)
                    crop_np[~crop_mask] = 0  # black out pixels outside the mask
                    crop_pil = Image.fromarray(crop_np)
                else:
                    crop_pil = img_pil  # degenerate bbox: fall back to full image

                label, conf = _classify_crop(clip_model, clip_proc,
                                             crop_pil, labels)
                # Geometric hole detection overrides the CLIP label.
                if is_hole(mask_bin, circ_thr):
                    label = "foro"
                masks_info.append({"Indice": j - 1,
                                   "Label": label,
                                   "Confidence": round(conf, 3),
                                   "Area(px)": int(m["area"])})
                inner_bar.progress(j / len(masks),
                                   text=f"Segmenti analizzati: {j}/{len(masks)}")
            inner_bar.empty()

        # ---- merge all "lamiera" masks into the largest one ----
        lamiera = [r for r in masks_info if r["Label"] == "lamiera"]
        if lamiera:
            idx_big = max(lamiera, key=lambda r: r["Area(px)"])["Indice"]
            union = np.zeros((H, W), bool)
            for r in lamiera:
                union |= masks[r["Indice"]]["segmentation"]
            masks[idx_big]["segmentation"] = union
            masks[idx_big]["area"] = int(union.sum())
            # Keep only the merged sheet row; drop the absorbed ones.
            masks_info = [r for r in masks_info
                          if r["Label"] != "lamiera" or r["Indice"] == idx_big]

        # ---- table + overlay ----
        df = pd.DataFrame(masks_info)
        st.dataframe(df, use_container_width=True)

        overlay = img_np.copy()
        for r in masks_info:
            seg_bin = (masks[r["Indice"]]["segmentation"] > 0.5).astype(np.uint8) * 255
            cnts, _ = cv2.findContours(seg_bin, cv2.RETR_EXTERNAL,
                                       cv2.CHAIN_APPROX_SIMPLE)
            color = tuple(np.random.randint(0, 255, 3).tolist())
            for c in cnts:
                cv2.drawContours(overlay, [c], -1, color, 2)
                x, y, w, h = cv2.boundingRect(c)
                cv2.putText(overlay, r["Label"], (x, y - 5),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
        st.image(overlay, caption="Overlay etichette", use_container_width=True)

        for r in masks_info:
            r["File"] = file.name
        all_rows += masks_info

        # update global progress bar
        global_bar.progress(idx_file / len(uploaded),
                            text=f"Immagini elaborate: {idx_file}/{len(uploaded)}")

    global_bar.empty()
    if all_rows:
        csv = pd.DataFrame(all_rows).to_csv(index=False).encode()
        st.download_button("📥 Scarica risultati (CSV)", csv,
                           "OR1_results.csv", "text/csv")