valegro committed on
Commit
bf3f0aa
·
verified ·
1 Parent(s): e182e5c

Rename pages/OR1_Riconoscimento.py to pages/OR1_visual_classification.py

Browse files
pages/OR1_Riconoscimento.py DELETED
@@ -1,80 +0,0 @@
1
- import streamlit as st, torch, cv2, numpy as np
2
- from PIL import Image
3
- from huggingface_hub import hf_hub_download
4
- from groundingdino.util.inference import load_model, predict
5
- from segment_anything import sam_model_registry, SamPredictor
6
-
7
# ---------- MODELS ------------------------------------------------------------
@st.cache_resource(show_spinner=False)
def load_grounding():
    """Build the GroundingDINO SwinT-OGC model.

    Downloads the checkpoint from the Hub (cached by huggingface_hub); the
    config file ``GroundingDINO_SwinT_OGC.py`` is expected to be available
    locally next to the app.
    """
    config_path = "GroundingDINO_SwinT_OGC.py"
    checkpoint_path = hf_hub_download(
        "IDEA-Research/grounding-dino-swint-ogc",
        filename="groundingdino_swint_ogc.pth",
    )
    return load_model(config_path, checkpoint_path)
14
-
15
@st.cache_resource(show_spinner=False)
def load_sam():
    """Download the SAM ViT-B checkpoint and return a ready ``SamPredictor``.

    Returns:
        SamPredictor wrapping the ViT-B model, moved to CUDA when available.
    """
    ckpt = hf_hub_download("facebook/sam-vit-base",
                           filename="sam_vit_b_01ec64.pth")
    sam = sam_model_registry["vit_b"](checkpoint=ckpt)
    # BUG FIX: SamPredictor is a plain helper class, not a torch nn.Module —
    # it has no .to() method, so the original `SamPredictor(sam).to(...)`
    # raised AttributeError. Move the underlying model first, then wrap it.
    sam.to("cuda" if torch.cuda.is_available() else "cpu")
    return SamPredictor(sam)
21
-
22
dino = load_grounding()
sam = load_sam()

# ---------- UI ---------------------------------------------------------------
st.header("OR1 – Riconoscimento zero‑shot (GroundingDINO + SAM)")

img_file = st.file_uploader("Carica immagine", type=["jpg","jpeg","png","webp"])
prompt = st.text_input("Classi da cercare (separate da virgola)",
                       "lamiera, foro circolare, foro rettangolare, vite, bullone")
box_th = st.slider("Soglia box (DINO)", 0.0, 1.0, 0.35, 0.01)
text_th = st.slider("Soglia testo (DINO)", 0.0, 1.0, 0.25, 0.01)

if img_file:
    img = Image.open(img_file).convert("RGB")
    im_np = np.array(img)
    H, W = im_np.shape[:2]

    # 1. GroundingDINO zero-shot detection.
    # NOTE(review): groundingdino's `predict` expects a *transformed* tensor
    # (see groundingdino.util.inference.load_image), not a raw numpy array —
    # confirm the preprocessing pipeline before relying on this call.
    # BUG FIX: predict() returns three values (boxes, logits, phrases); the
    # original two-name unpack raised ValueError on the first detection.
    boxes, logits, labels = predict(
        model=dino,
        image=im_np,
        caption=prompt,
        box_threshold=box_th,
        text_threshold=text_th,
    )

    if not len(boxes):
        st.warning("Nessun oggetto trovato – alza le soglie oppure modifica il prompt.")
        st.image(img, caption="Input")
        st.stop()

    # 2. SAM for detailed masks.
    # BUG FIX: GroundingDINO returns boxes as *normalized* (cx, cy, w, h);
    # the original scaled them as if they were already (x1, y1, x2, y2),
    # which drew shifted boxes. Convert explicitly before handing to SAM.
    scaled = boxes * torch.tensor([W, H, W, H])
    boxes_xyxy = torch.stack(
        (scaled[:, 0] - scaled[:, 2] / 2,   # x1 = cx - w/2
         scaled[:, 1] - scaled[:, 3] / 2,   # y1 = cy - h/2
         scaled[:, 0] + scaled[:, 2] / 2,   # x2 = cx + w/2
         scaled[:, 1] + scaled[:, 3] / 2),  # y2 = cy + h/2
        dim=1,
    )
    sam.set_image(im_np)

    vis = im_np.copy()
    counter = {}
    for box, label in zip(boxes_xyxy, labels):
        m, _, _ = sam.predict(box=box.cpu().numpy(), multimask_output=False)
        if m[0].mean() < .005:  # discard tiny blobs
            continue
        counter[label] = counter.get(label, 0) + 1
        # Draw mask outline and label (preview).
        cnt, _ = cv2.findContours(m[0].astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        color = tuple(int(x) for x in np.random.randint(0, 255, 3))
        cv2.drawContours(vis, cnt, -1, color, 2)
        x1, y1, x2, y2 = map(int, box)
        cv2.putText(vis, label, (x1, max(y1 - 5, 10)), cv2.FONT_HERSHEY_SIMPLEX, .5, color, 2)

    # 3. Output
    st.subheader("📊 Conteggio feature riconosciute")
    for k, v in counter.items():
        st.write(f"**{k}** : {v}")

    st.subheader("👁️‍🗨️ Preview")
    st.image(vis, caption="Mask outline + label")
    st.caption("Per GPU: Settings → Hardware → T4 small (o superiore)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pages/OR1_visual_classification.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import cv2
4
+ from PIL import Image
5
+ import pandas as pd
6
+
7
+ from utils import get_device, download_checkpoint
8
+ from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
9
+ from transformers import CLIPProcessor, CLIPModel
10
+
11
st.set_page_config(page_title="OR1 – Riconoscimento", layout="wide")
st.title("🧩 OR1 – Riconoscimento visivo e classificazione funzionale")

device = get_device()
st.sidebar.success(f"Device: **{device}**")

# SAM sampling parameters (sidebar controls).
st.sidebar.header("Parametri SAM")
# BUG FIX: the lower bound was 0 — SamAutomaticMaskGenerator cannot sample an
# empty point grid, so a 0 setting crashed segmentation. Enforce a sane minimum.
points_per_side = st.sidebar.slider("Points per side", 4, 128, 32)
pred_iou_thresh = st.sidebar.slider("Pred IoU Thresh", 0.0, 1.0, 0.8)
stability_score_thresh = st.sidebar.slider("Stability Score Thresh", 0.0, 1.0, 0.9)
22
+
23
@st.cache_resource
def _load_sam_model():
    """Download (once) and build the heavy SAM ViT-H model; cached for the session."""
    url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
    fname = "sam_vit_h_4b8939.pth"
    download_checkpoint(url, fname)
    return sam_model_registry["vit_h"](checkpoint=fname).to(device).eval()


def load_sam():
    """Return a SamAutomaticMaskGenerator configured from the current sliders.

    BUG FIX: the original cached the *generator* with @st.cache_resource while
    closing over the sidebar slider globals, so changing the sliders never had
    any effect (the cached generator kept the first-run values). Only the
    expensive model load is cached now; the cheap generator wrapper is rebuilt
    on every rerun with the live slider values.
    """
    return SamAutomaticMaskGenerator(
        _load_sam_model(),
        points_per_side=points_per_side,
        pred_iou_thresh=pred_iou_thresh,
        stability_score_thresh=stability_score_thresh
    )
35
+
36
@st.cache_resource
def load_clip():
    """Load the LAION CLIP ViT-L/14 model and its processor (cached per session)."""
    checkpoint = "laion/CLIP-ViT-L-14-laion2B-s32B-b82K"
    processor = CLIPProcessor.from_pretrained(checkpoint)
    model = CLIPModel.from_pretrained(checkpoint).to(device)
    return model, processor
42
+
43
# Instantiate the models (SAM first, then CLIP — same order as before).
mask_generator = load_sam()
clip_model, clip_processor = load_clip()

# Image upload
st.markdown("**1️⃣ Carica immagini** (JPG/PNG)")
uploaded = st.file_uploader("", type=["jpg", "jpeg", "png"], accept_multiple_files=True)

# Comma-separated class labels for zero-shot classification
st.markdown("**2️⃣ Inserisci etichette (es. lamiera, foro…)**")
default_labels = "lamiera, foro circolare, scanalatura rettangolare"
labels = [token.strip() for token in st.text_input("", default_labels).split(",") if token.strip()]
53
+
54
# Analysis: segment each uploaded image with SAM, label each segment with CLIP.
if uploaded and st.button("🔎 Analizza immagini"):
    all_results = []
    for file in uploaded:
        img_pil = Image.open(file).convert("RGB")
        img_np = np.array(img_pil)
        st.subheader(f"📎 {file.name}")
        st.image(img_pil, caption="Immagine originale", use_column_width=True)

        with st.spinner("Segmentazione in corso…"):
            masks = mask_generator.generate(img_np)
        st.write(f"→ Segmenti trovati: {len(masks)}")

        masks_info = []
        for idx, m in enumerate(masks):
            segm = m["segmentation"].astype(bool)

            # BUG FIX: the original passed the *whole image* to CLIP for every
            # mask, so all segments received the identical label (and the
            # binarized mask was computed but never used). Crop to the
            # segment's bounding box and blank out the background so CLIP
            # only sees the segment being classified.
            ys, xs = np.where(segm)
            if xs.size == 0:
                continue  # degenerate empty mask — nothing to classify
            y1, y2 = ys.min(), ys.max()
            x1, x2 = xs.min(), xs.max()
            crop = img_np[y1:y2 + 1, x1:x2 + 1].copy()
            crop[~segm[y1:y2 + 1, x1:x2 + 1]] = 255  # white background

            inputs = clip_processor(text=labels, images=Image.fromarray(crop),
                                    return_tensors="pt", padding=True)
            inputs = {k: v.to(device) for k, v in inputs.items()}
            out = clip_model(**inputs)
            probs = out.logits_per_image.softmax(dim=1)
            best_i = int(probs.argmax())
            label = labels[best_i]
            conf = float(probs[0, best_i])

            masks_info.append({
                "Indice": idx,
                "Label": label,
                "Confidence": round(conf, 3),
                "Area(px)": int(m["area"])
            })

        df = pd.DataFrame(masks_info)
        st.dataframe(df, use_container_width=True)

        # Overlay: contour + label per classified segment, random color each.
        annotated = img_np.copy()
        for info in masks_info:
            segm_bin = (masks[info["Indice"]]["segmentation"] > 0.5).astype(np.uint8) * 255
            contours, _ = cv2.findContours(segm_bin, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            color = tuple(np.random.randint(0, 255, 3).tolist())
            for cnt in contours:
                cv2.drawContours(annotated, [cnt], -1, color, 2)
                x, y, w, h = cv2.boundingRect(cnt)
                cv2.putText(annotated, info["Label"], (x, y - 5),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
        st.image(annotated, caption="Overlay con etichette", use_column_width=True)

        # Tag rows with the source file name and accumulate for the CSV export.
        for r in masks_info:
            r["File"] = file.name
        all_results += masks_info

    if all_results:
        df_all = pd.DataFrame(all_results)
        csv = df_all.to_csv(index=False).encode("utf-8")
        st.download_button("📥 Scarica risultati (CSV)", csv, "or1_results.csv", "text/csv")