import cv2
import numpy as np
import pandas as pd
import streamlit as st
import torch
from PIL import Image

from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from transformers import CLIPProcessor, CLIPModel

from utils import get_device, download_checkpoint

# Page configuration and sidebar controls for the SAM mask generator.
st.set_page_config(page_title="OR1 – Riconoscimento", layout="wide")
st.title("🧩 OR1 – Riconoscimento visivo e classificazione funzionale")

device = get_device()
st.sidebar.success(f"Device: **{device}**")

# SAM parameters exposed in the sidebar.
st.sidebar.header("Parametri SAM")
# BUG FIX: the lower bound was 0, but SamAutomaticMaskGenerator samples a
# points_per_side x points_per_side grid — 0 points would make segmentation
# fail outright, so the minimum is now 1.
points_per_side        = st.sidebar.slider("Points per side", 1, 128, 32)
pred_iou_thresh        = st.sidebar.slider("Pred IoU Thresh", 0.0, 1.0, 0.8)
stability_score_thresh = st.sidebar.slider("Stability Score Thresh", 0.0, 1.0, 0.9)

@st.cache_resource
def load_sam(points_per_side=points_per_side,
             pred_iou_thresh=pred_iou_thresh,
             stability_score_thresh=stability_score_thresh):
    """Download the SAM ViT-H checkpoint and build an automatic mask generator.

    BUG FIX: the original body read the slider values from module globals
    inside a cached function, so after the first run changing the sliders
    had no effect — st.cache_resource kept returning the stale generator.
    Taking them as parameters (defaults are re-bound to the current slider
    values on every Streamlit rerun) makes them part of the cache key, so a
    new generator is built whenever a slider changes. Calling `load_sam()`
    with no arguments remains valid.
    """
    url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
    fname = "sam_vit_h_4b8939.pth"
    # NOTE(review): presumably a no-op when the checkpoint already exists
    # locally — confirm against utils.download_checkpoint.
    download_checkpoint(url, fname)
    sam = sam_model_registry["vit_h"](checkpoint=fname).to(device).eval()
    return SamAutomaticMaskGenerator(
        sam,
        points_per_side=points_per_side,
        pred_iou_thresh=pred_iou_thresh,
        stability_score_thresh=stability_score_thresh,
    )

@st.cache_resource
def load_clip():
    """Load and cache the LAION CLIP ViT-L/14 model plus its processor."""
    checkpoint = "laion/CLIP-ViT-L-14-laion2B-s32B-b82K"
    model = CLIPModel.from_pretrained(checkpoint).to(device)
    processor = CLIPProcessor.from_pretrained(checkpoint)
    return model, processor

mask_generator, clip_model, clip_processor = load_sam(), *load_clip()

# Image upload widget
st.markdown("**1️⃣ Carica immagini** (JPG/PNG)")
uploaded = st.file_uploader("", type=["jpg", "jpeg", "png"], accept_multiple_files=True)

# Comma-separated candidate labels for zero-shot classification
st.markdown("**2️⃣ Inserisci etichette (es. lamiera, foro…)**")
default_labels = "lamiera, foro circolare, scanalatura rettangolare"
raw_labels = st.text_input("", default_labels)
labels = [token.strip() for token in raw_labels.split(",") if token.strip()]

# Analisi
# Analysis: segment each uploaded image with SAM, then zero-shot classify
# EACH segment with CLIP and render an annotated overlay plus a CSV export.
if uploaded and st.button("🔎 Analizza immagini"):
    all_results = []
    for file in uploaded:
        img_pil = Image.open(file).convert("RGB")
        img_np = np.array(img_pil)
        st.subheader(f"📎 {file.name}")
        st.image(img_pil, caption="Immagine originale", use_column_width=True)

        with st.spinner("Segmentazione in corso…"):
            masks = mask_generator.generate(img_np)
        st.write(f"→ Segmenti trovati: {len(masks)}")

        masks_info = []
        for idx, m in enumerate(masks):
            segm = m["segmentation"]  # boolean HxW mask produced by SAM

            # BUG FIX: the original ran CLIP on the FULL image for every
            # mask (the per-mask binary image was computed but never used),
            # so all segments received the identical label and confidence.
            # Classify only this segment's masked crop instead.
            x, y, w, h = (int(v) for v in m["bbox"])  # SAM bbox is XYWH
            if w == 0 or h == 0:
                continue  # degenerate segment: nothing to classify
            crop = img_np[y:y + h, x:x + w].copy()
            crop[~segm[y:y + h, x:x + w]] = 255  # white-out non-segment pixels

            inputs = clip_processor(text=labels, images=Image.fromarray(crop),
                                    return_tensors="pt", padding=True)
            inputs = {k: v.to(device) for k, v in inputs.items()}
            with torch.no_grad():  # inference only — no autograd graph needed
                out = clip_model(**inputs)
            probs = out.logits_per_image.softmax(dim=1)
            best_i = int(probs.argmax())

            masks_info.append({
                "Indice": idx,
                "Label": labels[best_i],
                "Confidence": round(float(probs[0, best_i]), 3),
                "Area(px)": int(m["area"]),
            })

        df = pd.DataFrame(masks_info)
        st.dataframe(df, use_container_width=True)

        # Draw each segment's contours and label on a copy of the image.
        annotated = img_np.copy()
        for info in masks_info:
            segm_bin = (masks[info["Indice"]]["segmentation"] > 0.5).astype(np.uint8) * 255
            contours, _ = cv2.findContours(segm_bin, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            color = tuple(np.random.randint(0, 255, 3).tolist())  # random BGR per segment
            for cnt in contours:
                cv2.drawContours(annotated, [cnt], -1, color, 2)
                bx, by, _, _ = cv2.boundingRect(cnt)
                cv2.putText(annotated, info["Label"], (bx, by - 5),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
        st.image(annotated, caption="Overlay con etichette", use_column_width=True)

        # Tag rows with the source file name before aggregating.
        for r in masks_info:
            r["File"] = file.name
        all_results += masks_info

    if all_results:
        df_all = pd.DataFrame(all_results)
        csv = df_all.to_csv(index=False).encode("utf-8")
        st.download_button("📥 Scarica risultati (CSV)", csv, "or1_results.csv", "text/csv")