File size: 8,515 Bytes
54d4548
a37245e
bf3f0aa
 
 
 
 
a37245e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54d4548
a37245e
 
54d4548
a37245e
54d4548
a37245e
 
 
54d4548
 
a37245e
bf3f0aa
 
 
a37245e
 
bf3f0aa
 
a37245e
 
 
bf3f0aa
 
54d4548
bf3f0aa
54d4548
a37245e
 
54d4548
a37245e
 
 
54d4548
 
a37245e
bf3f0aa
a37245e
 
 
54d4548
 
a37245e
 
 
 
 
 
 
 
 
 
 
 
bf3f0aa
a37245e
 
 
 
54d4548
 
a37245e
bf3f0aa
54d4548
 
 
bf3f0aa
54d4548
bf3f0aa
a37245e
 
 
54d4548
bf3f0aa
a37245e
bf3f0aa
a37245e
bf3f0aa
 
 
54d4548
a37245e
 
bf3f0aa
a37245e
 
 
 
 
 
54d4548
bf3f0aa
a37245e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf3f0aa
 
 
54d4548
 
a37245e
 
 
 
 
 
 
 
54d4548
bf3f0aa
 
 
54d4548
bf3f0aa
a37245e
 
 
 
 
 
54d4548
a37245e
54d4548
a37245e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
# pages/1_OR1_visual_classification.py
import streamlit as st, torch, numpy as np, cv2, pandas as pd
from PIL import Image
from utils import get_device, download_checkpoint
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from transformers import CLIPProcessor, CLIPModel

# ────────────────────────────────────────────────────────────────
# 1)  Fix torchvision NMS (GPU errors) ───────────────────────────
# ────────────────────────────────────────────────────────────────
device = get_device()
if device.type == "cuda":
    # Work around GPU errors in torchvision's vanilla batched NMS by making
    # sure `idxs` lives on the same device as `boxes` before delegating.
    try:
        from torchvision.ops import boxes as _bx

        # Streamlit re-executes this script on every interaction, while the
        # already-imported torchvision module object is reused. Without this
        # guard each rerun would wrap the previous wrapper again and the call
        # chain would keep growing (eventually hitting the recursion limit).
        if not getattr(_bx._batched_nms_vanilla, "_or1_device_fix", False):
            _orig_nms = _bx._batched_nms_vanilla

            def _batched_nms_fix(boxes, scores, idxs, thr):
                """Call the original NMS with ``idxs`` moved to the boxes' device."""
                return _orig_nms(boxes, scores, idxs.to(boxes.device), thr)

            # Marker checked above so the patch is applied at most once.
            _batched_nms_fix._or1_device_fix = True
            _bx._batched_nms_vanilla = _batched_nms_fix
    except Exception as e:
        # Best-effort patch of a private torchvision function: report the
        # failure in the UI and keep the page usable.
        st.warning(f"Patch NMS non riuscita: {e}")

# ────────────────────────────────────────────────────────────────
# 2)  Helper functions  ──────────────────────────────────────────
# ────────────────────────────────────────────────────────────────
def is_hole(mask_bin, min_circ=0.7, min_area=200):
    """Tell whether the largest outer contour of a mask is big and round enough.

    Args:
        mask_bin: Binary mask image handed to ``cv2.findContours``
            (the caller in this file passes a uint8 array of 0/255).
        min_circ: Minimum circularity, computed as
            ``4 * pi * area / perimeter**2``.
        min_area: Minimum contour area.

    Returns:
        A truthy value when the largest external contour meets both
        thresholds; ``False`` when it does not or when no contour is found.
    """
    contours, _hierarchy = cv2.findContours(
        mask_bin, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    if len(contours) == 0:
        return False

    # Only the largest external contour is evaluated.
    largest = max(contours, key=cv2.contourArea)
    largest_area = cv2.contourArea(largest)

    perimeter = cv2.arcLength(largest, True)
    if not perimeter:
        perimeter = 1  # avoid dividing by zero on a degenerate contour

    circularity = 4 * np.pi * largest_area / perimeter**2

    big_enough = largest_area >= min_area
    round_enough = circularity >= min_circ
    return big_enough and round_enough

@st.cache_resource(show_spinner=False)
def _load_sam_model(cpu=False):
    """Load the SAM ViT-H network, cached once per target device.

    Kept separate from ``load_sam`` so that changing the generator thresholds
    does not instantiate (and keep cached) another copy of the weights.

    Args:
        cpu: When true the model is placed on the CPU, otherwise on the
            module-level ``device``.

    Returns:
        The SAM model in eval mode on the selected device.
    """
    url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
    fname = "sam_vit_h_4b8939.pth"
    # NOTE(review): presumably a no-op when the file is already present —
    # confirm against utils.download_checkpoint.
    download_checkpoint(url, fname)
    model_dev = torch.device("cpu") if cpu else device
    return sam_model_registry["vit_h"](checkpoint=fname).to(model_dev).eval()


@st.cache_resource(show_spinner=False)
def load_sam(pts, iou, stab, cpu=False):
    """Build a SAM automatic mask generator for the given settings.

    The underlying network comes from ``_load_sam_model`` and is shared by
    every generator targeting the same device; only the lightweight generator
    wrapper is created per parameter combination.

    Args:
        pts: Value forwarded as ``points_per_side``.
        iou: Value forwarded as ``pred_iou_thresh``.
        stab: Value forwarded as ``stability_score_thresh``.
        cpu: Force the model onto the CPU instead of the module-level
            ``device``.

    Returns:
        A ``SamAutomaticMaskGenerator`` configured with the given settings.
    """
    return SamAutomaticMaskGenerator(
        _load_sam_model(bool(cpu)),
        points_per_side=pts,
        pred_iou_thresh=iou,
        stability_score_thresh=stab,
    )

@st.cache_resource(show_spinner=False)
def load_clip():
    """Load the CLIP model and its processor, cached across Streamlit reruns.

    Returns:
        A ``(model, processor)`` tuple; the model is moved to the
        module-level ``device``.
    """
    checkpoint_id = "laion/CLIP-ViT-L-14-laion2B-s32B-b82K"
    model = CLIPModel.from_pretrained(checkpoint_id).to(device)
    processor = CLIPProcessor.from_pretrained(checkpoint_id)
    return model, processor

# ────────────────────────────────────────────────────────────────
# 3)   UI  ────────────────────────────────────────────────────────
# ────────────────────────────────────────────────────────────────
# Page chrome: title and the device reported by get_device().
st.set_page_config(page_title="OR1 – Riconoscimento", layout="wide")
st.title("🧩 OR1 – Riconoscimento visivo & classificazione")
st.sidebar.success(f"Device: **{device.type}**")

# Tunables for SAM mask generation and for the hole heuristic.
with st.sidebar.expander("Impostazioni avanzate", expanded=False):
    # Lower bound is 1: SAM needs at least one prompt point per side to build
    # its sampling grid, so 0 is not a usable value.
    pts  = st.slider("Points per side", 1, 128, 32)
    iou  = st.slider("Pred IoU threshold", 0.0, 1.0, 0.8)
    stab = st.slider("Stability score threshold", 0.0, 1.0, 0.9)
    min_area_px = st.number_input("Area minima segmento (px)", 0, 50_000, 500)
    circ_thr    = st.slider("Soglia circolarità foro", 0.0, 1.0, 0.7)
    force_cpu   = st.checkbox("Forza SAM su CPU (workaround errori GPU)")

# Input images and the comma-separated candidate labels for CLIP.
uploaded = st.file_uploader("Carica immagini JPG/PNG", type=["jpg","jpeg","png"],
                            accept_multiple_files=True)
# Split on commas, trim whitespace and drop empty entries.
labels = [x.strip() for x in st.text_input(
            "Etichette (separate da virgola)",
            "lamiera, scanalatura rettangolare").split(",") if x.strip()]

# ────────────────────────────────────────────────────────────────
# 4)  Analysis  ──────────────────────────────────────────────────
# ────────────────────────────────────────────────────────────────
def _segment_crop(img_np, segm):
    """Return a PIL image of one segment: bounding-box crop, background zeroed.

    Falls back to the full image when the mask has no foreground pixels.
    """
    ys, xs = np.nonzero(segm)
    if ys.size == 0:
        return Image.fromarray(img_np)
    y0, y1 = int(ys.min()), int(ys.max()) + 1
    x0, x1 = int(xs.min()), int(xs.max()) + 1
    crop = img_np[y0:y1, x0:x1].copy()
    region = np.asarray(segm[y0:y1, x0:x1]).astype(bool)
    crop[~region] = 0
    return Image.fromarray(crop)


def _classify_crop(clip_model, clip_proc, crop, labels):
    """Score one image crop against the text labels; return (label, confidence)."""
    inputs = clip_proc(text=labels, images=crop,
                       return_tensors="pt", padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = clip_model(**inputs).logits_per_image
    probs = logits.softmax(1)
    best = int(probs.argmax())
    return labels[best], float(probs[0, best])


if uploaded and st.button("🔎 Analizza immagini"):
    if not labels:
        # CLIP needs at least one text prompt to score against.
        st.error("Inserisci almeno un'etichetta.")
        st.stop()

    clip_model, clip_proc = load_clip()
    sam_gen = load_sam(pts, iou, stab, cpu=force_cpu)

    global_bar = st.progress(0, text="Analisi immagini")
    all_rows = []

    for idx_file, file in enumerate(uploaded, start=1):
        img_pil = Image.open(file).convert("RGB")
        img_np  = np.array(img_pil)
        H, W = img_np.shape[:2]

        st.subheader(f"📎 {file.name}")
        st.image(img_pil, caption="Immagine originale", use_container_width=True)

        # ---- segmentation, then drop segments below the minimum area ----
        with st.spinner("Segmentazione + classificazione…"):
            masks = sam_gen.generate(img_np)
        masks = [m for m in masks if m["area"] >= min_area_px]

        # ---- per-segment classification with an inner progress bar ----
        inner_bar = st.progress(0)
        masks_info = []
        for j, m in enumerate(masks, start=1):
            segm = m["segmentation"]
            mask_bin = (segm * 255).astype(np.uint8)

            # Classify the segment itself, not the whole picture: feeding the
            # full image would give every segment the same label and score.
            label, conf = _classify_crop(
                clip_model, clip_proc, _segment_crop(img_np, segm), labels)

            # The geometric hole heuristic overrides the CLIP label.
            if is_hole(mask_bin, circ_thr):
                label = "foro"

            masks_info.append({"Indice": j-1, "Label": label,
                               "Confidence": round(conf,3),
                               "Area(px)": int(m["area"])})

            inner_bar.progress(j/len(masks), text=f"Segmenti analizzati: {j}/{len(masks)}")

        inner_bar.empty()

        # ---- merge every "lamiera" segment into the largest one ----
        lamiera = [r for r in masks_info if r["Label"] == "lamiera"]
        if lamiera:
            biggest = max(lamiera, key=lambda r: r["Area(px)"])
            idx_big = biggest["Indice"]
            union = np.zeros((H,W), bool)
            for r in lamiera:
                union |= masks[r["Indice"]]["segmentation"]
            merged_area = int(union.sum())
            masks[idx_big]["segmentation"] = union
            masks[idx_big]["area"] = merged_area
            # Keep the table row in sync with the merged mask.
            biggest["Area(px)"] = merged_area
            masks_info = [r for r in masks_info
                          if r["Label"] != "lamiera" or r["Indice"] == idx_big]

        # ---- table + overlay ----
        df = pd.DataFrame(masks_info)
        st.dataframe(df, use_container_width=True)

        overlay = img_np.copy()
        for r in masks_info:
            seg_bin = (masks[r["Indice"]]["segmentation"]>0.5).astype(np.uint8)*255
            cnts,_ = cv2.findContours(seg_bin, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            color = tuple(np.random.randint(0,255,3).tolist())
            for c in cnts:
                cv2.drawContours(overlay,[c],-1,color,2)
                x,y,w,h = cv2.boundingRect(c)
                # Clamp the baseline so labels of contours touching the top
                # border are still drawn inside the image.
                cv2.putText(overlay,r["Label"],(x,max(y-5,12)),
                            cv2.FONT_HERSHEY_SIMPLEX,0.5,color,2)
        st.image(overlay, caption="Overlay etichette", use_container_width=True)

        # Tag each row with its source file for the combined CSV.
        for r in masks_info:
            r["File"] = file.name
        all_rows += masks_info

        # update the global progress bar
        global_bar.progress(idx_file/len(uploaded),
                            text=f"Immagini elaborate: {idx_file}/{len(uploaded)}")

    global_bar.empty()

    if all_rows:
        csv = pd.DataFrame(all_rows).to_csv(index=False).encode()
        st.download_button("📥 Scarica risultati (CSV)", csv,
                           "OR1_results.csv", "text/csv")