Spaces:
Paused
Paused
| # pages/1_OR1_visual_classification.py | |
| import streamlit as st, torch, numpy as np, cv2, pandas as pd | |
| from PIL import Image | |
| from utils import get_device, download_checkpoint | |
| from segment_anything import sam_model_registry, SamAutomaticMaskGenerator | |
| from transformers import CLIPProcessor, CLIPModel | |
# ────────────────────────────────────────────────────────────────
# 1) torchvision NMS fix (GPU errors) ────────────────────────────
# ────────────────────────────────────────────────────────────────
device = get_device()


def _install_nms_device_fix():
    """Wrap torchvision's ``_batched_nms_vanilla`` so ``idxs`` follows ``boxes``.

    The wrapper moves ``idxs`` onto ``boxes.device`` before delegating to the
    original implementation.

    The patch is idempotent: Streamlit re-executes this script on every
    interaction while ``torchvision.ops.boxes`` stays imported, so without
    the marker check each rerun would wrap the previous wrapper and the call
    chain would grow without bound.
    """
    from torchvision.ops import boxes as _bx

    original = _bx._batched_nms_vanilla
    if getattr(original, "_nms_device_fix", False):
        # Already patched by an earlier run of this script.
        return

    def _batched_nms_fix(boxes, scores, idxs, thr):
        # `original` is captured by the closure, so later reassignments of
        # module-level names cannot redirect this call.
        return original(boxes, scores, idxs.to(boxes.device), thr)

    _batched_nms_fix._nms_device_fix = True
    _bx._batched_nms_vanilla = _batched_nms_fix


if device.type == "cuda":
    try:
        _install_nms_device_fix()
    except Exception as e:
        # Best effort: the app still runs without the patch (SAM can also be
        # forced onto the CPU from the sidebar), so only surface a warning.
        st.warning(f"Patch NMS non riuscita: {e}")
# ────────────────────────────────────────────────────────────────
# 2) Helper functions ────────────────────────────────────────────
# ────────────────────────────────────────────────────────────────
def is_hole(mask_bin, min_circ=0.7, min_area=200):
    """Tell whether the largest external contour of *mask_bin* looks like a hole.

    The largest external contour (by ``cv2.contourArea``) must have an area
    of at least *min_area* and a circularity ``4*pi*area / perimeter**2`` of
    at least *min_circ*.

    Args:
        mask_bin: Binary mask image handed to ``cv2.findContours``.
        min_circ: Minimum circularity for the contour to qualify.
        min_area: Minimum contour area for the contour to qualify.

    Returns:
        ``True`` when both thresholds are met, ``False`` otherwise (including
        when no contour is found).
    """
    contours, _hierarchy = cv2.findContours(
        mask_bin, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    if len(contours) == 0:
        return False

    largest = max(contours, key=cv2.contourArea)
    area = cv2.contourArea(largest)

    perimeter = cv2.arcLength(largest, True)
    if not perimeter:
        # A zero-length perimeter would divide by zero below.
        perimeter = 1
    circularity = 4 * np.pi * area / perimeter**2

    if area < min_area:
        return False
    return circularity >= min_circ
@st.cache_resource
def _load_sam_model(device_name):
    """Load the SAM ViT-H model onto *device_name*, cached across reruns.

    Cached with ``st.cache_resource`` and keyed on the device string, so the
    checkpoint is read from disk once per device instead of on every click.
    Note that a model loaded on one device stays cached when the user
    switches to another.
    """
    url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
    fname = "sam_vit_h_4b8939.pth"
    # Project helper; presumably fetches the file only when missing — verify.
    download_checkpoint(url, fname)
    return sam_model_registry["vit_h"](checkpoint=fname).to(device_name).eval()


def load_sam(pts, iou, stab, cpu=False):
    """Build a ``SamAutomaticMaskGenerator`` around the cached SAM ViT-H model.

    Only the model is cached. The generator is rebuilt on each call, so
    changing the thresholds never triggers a model reload.

    Args:
        pts: Forwarded as ``points_per_side``.
        iou: Forwarded as ``pred_iou_thresh``.
        stab: Forwarded as ``stability_score_thresh``.
        cpu: When true, place the model on the CPU instead of the
            module-level ``device``.

    Returns:
        A ``SamAutomaticMaskGenerator`` configured with the given settings.
    """
    model_dev = torch.device("cpu") if cpu else device
    # A plain string is used as the cache key: it is trivially hashable and
    # is accepted directly by ``Module.to``.
    sam = _load_sam_model(str(model_dev))
    return SamAutomaticMaskGenerator(
        sam,
        points_per_side=pts,
        pred_iou_thresh=iou,
        stability_score_thresh=stab,
    )
@st.cache_resource
def load_clip():
    """Load the CLIP model and its processor, cached across Streamlit reruns.

    Without caching, the weights would be reloaded on every click of the
    analysis button.

    Returns:
        A ``(model, processor)`` tuple; the model is moved to the
        module-level ``device``.
    """
    name = "laion/CLIP-ViT-L-14-laion2B-s32B-b82K"
    model = CLIPModel.from_pretrained(name).to(device)
    processor = CLIPProcessor.from_pretrained(name)
    return model, processor
# ────────────────────────────────────────────────────────────────
# 3) UI ──────────────────────────────────────────────────────────
# ────────────────────────────────────────────────────────────────
# Page setup and input widgets. The values read here (pts, iou, stab,
# min_area_px, circ_thr, force_cpu, uploaded, labels) drive the analysis.
st.set_page_config(page_title="OR1 β Riconoscimento", layout="wide")
st.title("π§© OR1Β βΒ Riconoscimento visivo & classificazione")
st.sidebar.success(f"Device: **{device.type}**")

with st.sidebar.expander("Impostazioni avanzate", expanded=False):
    # The minimum is 1, not 0: SAM builds its prompt grid with an offset of
    # 1 / (2 * points_per_side), so 0 would raise ZeroDivisionError.
    pts = st.slider("Points per side", 1, 128, 32)
    iou = st.slider("Pred IoU threshold", 0.0, 1.0, 0.8)
    stab = st.slider("Stability score threshold", 0.0, 1.0, 0.9)
    min_area_px = st.number_input("Area minima segmento (px)", 0, 50_000, 500)
    circ_thr = st.slider("Soglia circolaritΓ foro", 0.0, 1.0, 0.7)
    force_cpu = st.checkbox("Forza SAM su CPU (workaround errori GPU)")

uploaded = st.file_uploader(
    "Carica immagini JPG/PNG",
    type=["jpg", "jpeg", "png"],
    accept_multiple_files=True,
)

# Comma-separated labels, whitespace-trimmed, empty entries dropped.
labels = [
    x.strip()
    for x in st.text_input(
        "Etichette (separate da virgola)",
        "lamiera, scanalatura rettangolare",
    ).split(",")
    if x.strip()
]
# ────────────────────────────────────────────────────────────────
# 4) Analysis ────────────────────────────────────────────────────
# ────────────────────────────────────────────────────────────────
def _classify_segment(clip_model, clip_proc, labels, img_pil, mask_bin):
    """Return ``(label, confidence)`` from CLIP for the region under *mask_bin*.

    The image is cropped to the bounding box of the mask's non-zero pixels,
    so each segment gets its own prediction. Classifying the full image for
    every segment would give every row the same label and confidence.
    """
    x, y, w, h = cv2.boundingRect(mask_bin)
    # An all-zero mask yields an empty rectangle; fall back to the full image.
    crop = img_pil.crop((x, y, x + w, y + h)) if w > 0 and h > 0 else img_pil
    inputs = clip_proc(text=labels, images=crop,
                       return_tensors="pt", padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = clip_model(**inputs).logits_per_image
    probs = logits.softmax(1)
    return labels[int(probs.argmax())], float(probs.max())


def _merge_label(masks, rows, label, shape):
    """Merge all segments tagged *label* into the largest one.

    The largest row's mask becomes the union of all tagged masks, and both
    the mask's ``area`` and the row's ``Area(px)`` are updated to the union
    size. Returns *rows* without the other rows carrying *label*.
    """
    tagged = [r for r in rows if r["Label"] == label]
    if not tagged:
        return rows
    biggest = max(tagged, key=lambda r: r["Area(px)"])
    idx_big = biggest["Indice"]
    union = np.zeros(shape, bool)
    for r in tagged:
        union |= masks[r["Indice"]]["segmentation"]
    merged_area = int(union.sum())
    masks[idx_big]["segmentation"] = union
    masks[idx_big]["area"] = merged_area
    # Keep the table/CSV consistent with the merged mask.
    biggest["Area(px)"] = merged_area
    return [r for r in rows if r["Label"] != label or r["Indice"] == idx_big]


def _draw_overlay(img_np, masks, rows):
    """Return a copy of *img_np* with each row's contours and label drawn on it."""
    overlay = img_np.copy()
    for r in rows:
        seg_bin = (masks[r["Indice"]]["segmentation"] > 0.5).astype(np.uint8) * 255
        cnts, _ = cv2.findContours(seg_bin, cv2.RETR_EXTERNAL,
                                   cv2.CHAIN_APPROX_SIMPLE)
        color = tuple(np.random.randint(0, 255, 3).tolist())
        for c in cnts:
            cv2.drawContours(overlay, [c], -1, color, 2)
            x, y, w, h = cv2.boundingRect(c)
            # Clamp the baseline so labels of contours touching the top edge
            # are not drawn outside the image.
            cv2.putText(overlay, r["Label"], (x, max(y - 5, 12)),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
    return overlay


if uploaded and st.button("π Analizza immagini"):
    if not labels:
        # CLIP needs at least one text prompt to score against.
        st.error("Inserisci almeno un'etichetta.")
        st.stop()

    clip_model, clip_proc = load_clip()
    sam_gen = load_sam(pts, iou, stab, cpu=force_cpu)
    global_bar = st.progress(0, text="Analisi immagini")
    all_rows = []

    for idx_file, file in enumerate(uploaded, start=1):
        img_pil = Image.open(file).convert("RGB")
        img_np = np.array(img_pil)
        H, W = img_np.shape[:2]
        st.subheader(f"π {file.name}")
        st.image(img_pil, caption="Immagine originale", use_container_width=True)

        # ---- segmentation with a per-image progress bar ----
        with st.spinner("Segmentazione + classificazioneβ¦"):
            masks = sam_gen.generate(img_np)
            masks = [m for m in masks if m["area"] >= min_area_px]
            inner_bar = st.progress(0)
            masks_info = []
            for j, m in enumerate(masks, start=1):
                mask_bin = (m["segmentation"] * 255).astype(np.uint8)
                label, conf = _classify_segment(
                    clip_model, clip_proc, labels, img_pil, mask_bin)
                # The geometric test overrides the CLIP label.
                if is_hole(mask_bin, circ_thr):
                    label = "foro"
                masks_info.append({"Indice": j - 1, "Label": label,
                                   "Confidence": round(conf, 3),
                                   "Area(px)": int(m["area"])})
                inner_bar.progress(
                    j / len(masks),
                    text=f"Segmenti analizzati: {j}/{len(masks)}")
            inner_bar.empty()

        # ---- merge every "lamiera" segment into the largest one ----
        masks_info = _merge_label(masks, masks_info, "lamiera", (H, W))

        # ---- table + overlay ----
        df = pd.DataFrame(masks_info)
        st.dataframe(df, use_container_width=True)
        overlay = _draw_overlay(img_np, masks, masks_info)
        st.image(overlay, caption="Overlay etichette", use_container_width=True)

        for r in masks_info:
            r["File"] = file.name
        all_rows += masks_info

        # update the global progress bar
        global_bar.progress(idx_file / len(uploaded),
                            text=f"Immagini elaborate: {idx_file}/{len(uploaded)}")

    global_bar.empty()
    if all_rows:
        csv = pd.DataFrame(all_rows).to_csv(index=False).encode()
        st.download_button("π₯ Scarica risultati (CSV)", csv,
                           "OR1_results.csv", "text/csv")