# pages/1_OR1_visual_classification.py
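"""OR1 page: segments uploaded images with SAM (ViT-H), assigns each segment a
zero-shot label with CLIP, flags near-circular segments as holes ("foro"),
and offers the per-segment results as a CSV download."""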
import streamlit as st, torch, numpy as np, cv2, pandas as pd
from PIL import Image
from utils import get_device, download_checkpoint
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from transformers import CLIPProcessor, CLIPModel
# ────────────────────────────────────────────────────────────────
# 1) Patch torchvision NMS (GPU errors) ─────────────────────────
# ────────────────────────────────────────────────────────────────
device = get_device()
if device.type == "cuda":
    try:
        from torchvision.ops import boxes as _bx
        _orig_nms = _bx._batched_nms_vanilla

        def _batched_nms_fix(boxes, scores, idxs, thr):
            return _orig_nms(boxes, scores, idxs.to(boxes.device), thr)

        _bx._batched_nms_vanilla = _batched_nms_fix
    except Exception as e:
        st.warning(f"NMS patch failed: {e}")
# ────────────────────────────────────────────────────────────────
# 2) Helper functions ──────────────────────────────────────────
# ────────────────────────────────────────────────────────────────
def is_hole(mask_bin, min_circ=0.7, min_area=200):
    """Return True if the largest contour in a binary mask looks like a hole
    (area above min_area and circularity 4*pi*A/P^2 above min_circ)."""
    cnts, _ = cv2.findContours(mask_bin, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not cnts:
        return False
    cnt = max(cnts, key=cv2.contourArea)
    area = cv2.contourArea(cnt)
    peri = cv2.arcLength(cnt, True) or 1
    circ = 4 * np.pi * area / peri ** 2
    return area >= min_area and circ >= min_circ
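# Circularity 4*pi*A/P^2 equals 1.0 for a perfect circle and pi/4 ≈ 0.785 for a
# square, so the default threshold of 0.7 accepts round openings while rejecting
# elongated or ragged contours.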
@st.cache_resource(show_spinner=False)
def load_sam(pts, iou, stab, cpu=False):
url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
fname = "sam_vit_h_4b8939.pth"
download_checkpoint(url, fname)
model_dev = torch.device("cpu") if cpu else device
sam = sam_model_registry["vit_h"](checkpoint=fname).to(model_dev).eval()
return SamAutomaticMaskGenerator(
sam,
points_per_side=pts,
pred_iou_thresh=iou,
stability_score_thresh=stab,
)
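# The ViT-H checkpoint is roughly 2.4 GB; @st.cache_resource keeps the generator
# in memory so the download/load happens once per parameter combination rather
# than on every rerun.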
@st.cache_resource(show_spinner=False)
def load_clip():
name = "laion/CLIP-ViT-L-14-laion2B-s32B-b82K"
return (CLIPModel.from_pretrained(name).to(device),
CLIPProcessor.from_pretrained(name))
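# The processor bundles CLIP's tokenizer and image preprocessor, so a single call
# can prepare both the label prompts and an image crop for zero-shot scoring.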
# ────────────────────────────────────────────────────────────────
# 3) UI ────────────────────────────────────────────────────────
# ────────────────────────────────────────────────────────────────
st.set_page_config(page_title="OR1 – Recognition", layout="wide")
st.title("🧩 OR1 – Visual recognition & classification")
st.sidebar.success(f"Device: **{device.type}**")

with st.sidebar.expander("Advanced settings", expanded=False):
    pts = st.slider("Points per side", 0, 128, 32)
    iou = st.slider("Pred IoU threshold", 0.0, 1.0, 0.8)
    stab = st.slider("Stability score threshold", 0.0, 1.0, 0.9)
    min_area_px = st.number_input("Minimum segment area (px)", 0, 50_000, 500)
    circ_thr = st.slider("Hole circularity threshold", 0.0, 1.0, 0.7)
    force_cpu = st.checkbox("Force SAM onto CPU (workaround for GPU errors)")

uploaded = st.file_uploader("Upload JPG/PNG images", type=["jpg", "jpeg", "png"],
                            accept_multiple_files=True)
labels = [x.strip() for x in st.text_input(
    "Labels (comma-separated)",
    "lamiera, scanalatura rettangolare").split(",") if x.strip()]
# ────────────────────────────────────────────────────────────────
# 4) Analysis ──────────────────────────────────────────────────
# ────────────────────────────────────────────────────────────────
if uploaded and st.button("🔎 Analyze images"):
    clip_model, clip_proc = load_clip()
    sam_gen = load_sam(pts, iou, stab, cpu=force_cpu)
    global_bar = st.progress(0, text="Analyzing images")
    all_rows = []

    for idx_file, file in enumerate(uploaded, start=1):
        img_pil = Image.open(file).convert("RGB")
        img_np = np.array(img_pil)
        H, W = img_np.shape[:2]
        st.subheader(f"📎 {file.name}")
        st.image(img_pil, caption="Original image", use_container_width=True)
        # ---- segmentation with inner progress bar ----
        with st.spinner("Segmentation + classification…"):
            masks = sam_gen.generate(img_np)
            masks = [m for m in masks if m["area"] >= min_area_px]
            inner_bar = st.progress(0)
            masks_info = []
            for j, m in enumerate(masks, start=1):
                segm = m["segmentation"]
                mask_bin = (segm * 255).astype(np.uint8)
                # CLIP: score the crop around this segment (SAM bboxes are XYWH)
                x0, y0, bw, bh = (int(v) for v in m["bbox"])
                crop = img_pil.crop((x0, y0, x0 + max(bw, 1), y0 + max(bh, 1)))
                inputs = clip_proc(text=labels, images=crop,
                                   return_tensors="pt", padding=True)
                inputs = {k: v.to(device) for k, v in inputs.items()}
                with torch.no_grad():
                    logits = clip_model(**inputs).logits_per_image
                probs = logits.softmax(1)
                label = labels[int(probs.argmax())]
                conf = float(probs.max())
                if is_hole(mask_bin, circ_thr):
                    label = "foro"
                masks_info.append({"Indice": j-1, "Label": label,
                                   "Confidence": round(conf, 3),
                                   "Area(px)": int(m["area"])})
                inner_bar.progress(j/len(masks), text=f"Segments analyzed: {j}/{len(masks)}")
            inner_bar.empty()
        # ---- merge all "lamiera" masks into the largest one ----
        lamiera = [r for r in masks_info if r["Label"] == "lamiera"]
        if lamiera:
            idx_big = max(lamiera, key=lambda r: r["Area(px)"])["Indice"]
            union = np.zeros((H, W), bool)
            for r in lamiera:
                union |= masks[r["Indice"]]["segmentation"]
            masks[idx_big]["segmentation"] = union
            masks[idx_big]["area"] = int(union.sum())
            masks_info = [r for r in masks_info
                          if r["Label"] != "lamiera" or r["Indice"] == idx_big]
        # ---- table + overlay ----
        df = pd.DataFrame(masks_info)
        st.dataframe(df, use_container_width=True)
        overlay = img_np.copy()
        for r in masks_info:
            seg_bin = (masks[r["Indice"]]["segmentation"] > 0.5).astype(np.uint8) * 255
            cnts, _ = cv2.findContours(seg_bin, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            color = tuple(np.random.randint(0, 255, 3).tolist())
            for c in cnts:
                cv2.drawContours(overlay, [c], -1, color, 2)
                x, y, w, h = cv2.boundingRect(c)
                cv2.putText(overlay, r["Label"], (x, y - 5),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
        st.image(overlay, caption="Label overlay", use_container_width=True)
        for r in masks_info:
            r["File"] = file.name
        all_rows += masks_info

        # update the global progress bar
        global_bar.progress(idx_file/len(uploaded),
                            text=f"Images processed: {idx_file}/{len(uploaded)}")

    global_bar.empty()
    if all_rows:
        csv = pd.DataFrame(all_rows).to_csv(index=False).encode()
        st.download_button("📥 Download results (CSV)", csv,
                           "OR1_results.csv", "text/csv")