Spaces:
Paused
Paused
Rename pages/OR1_Riconoscimento.py to pages/OR1_visual_classification.py
Browse files- pages/OR1_Riconoscimento.py +0 -80
- pages/OR1_visual_classification.py +109 -0
pages/OR1_Riconoscimento.py
DELETED
|
@@ -1,80 +0,0 @@
|
|
| 1 |
-
import streamlit as st, torch, cv2, numpy as np
|
| 2 |
-
from PIL import Image
|
| 3 |
-
from huggingface_hub import hf_hub_download
|
| 4 |
-
from groundingdino.util.inference import load_model, predict
|
| 5 |
-
from segment_anything import sam_model_registry, SamPredictor
|
| 6 |
-
|
| 7 |
-
# ---------- MODELLI -----------------------------------------------------------
|
| 8 |
-
@st.cache_resource(show_spinner=False)
def load_grounding():
    """Build the GroundingDINO detector once (cached across Streamlit reruns).

    Downloads the SwinT-OGC checkpoint from the Hugging Face Hub.
    NOTE(review): the config path is relative — presumably resolved against
    the app's working directory; verify the file ships with the Space.
    """
    config_path = "GroundingDINO_SwinT_OGC.py"
    weights_path = hf_hub_download(
        "IDEA-Research/grounding-dino-swint-ogc",
        filename="groundingdino_swint_ogc.pth",
    )
    return load_model(config_path, weights_path)
|
| 14 |
-
|
| 15 |
-
@st.cache_resource(show_spinner=False)
def load_sam():
    """Download the SAM ViT-B checkpoint and return a ready SamPredictor.

    BUG FIX: the previous version called ``.to(device)`` on the
    ``SamPredictor`` instance, which has no such method — device placement
    belongs on the underlying model before the predictor wraps it.
    """
    ckpt = hf_hub_download("facebook/sam-vit-base",
                           filename="sam_vit_b_01ec64.pth")
    sam = sam_model_registry["vit_b"](checkpoint=ckpt)
    # Move the *model* (not the predictor) to the available device.
    sam.to("cuda" if torch.cuda.is_available() else "cpu")
    return SamPredictor(sam)
|
| 21 |
-
|
| 22 |
-
# ---------- MODELLI (instantiated once; bodies are st.cache_resource'd) -------
dino = load_grounding()
sam = load_sam()

# ---------- UI ---------------------------------------------------------------
st.header("OR1 – Riconoscimento zero‑shot (GroundingDINO + SAM)")

img_file = st.file_uploader("Carica immagine", type=["jpg","jpeg","png","webp"])
prompt = st.text_input("Classi da cercare (separate da virgola)",
                       "lamiera, foro circolare, foro rettangolare, vite, bullone")
box_th = st.slider("Soglia box (DINO)", 0.0,1.0,0.35,0.01)
text_th = st.slider("Soglia testo (DINO)",0.0,1.0,0.25,0.01)

if img_file:
    img = Image.open(img_file).convert("RGB")
    im_np = np.array(img)
    H, W = im_np.shape[:2]

    # 1. GroundingDINO zero-shot detection.
    # BUG FIX: predict() returns THREE values (boxes, logits, phrases); the
    # old two-value unpack raised ValueError on every run.
    # NOTE(review): groundingdino's predict() expects a normalized image
    # tensor rather than a raw HxWx3 numpy array — confirm preprocessing
    # against the library's inference utilities.
    boxes, logits, labels = predict(
        model=dino,
        image=im_np,
        caption=prompt,
        box_threshold=box_th,
        text_threshold=text_th,
    )

    if not len(boxes):
        st.warning("Nessun oggetto trovato – alza le soglie oppure modifica il prompt.")
        st.image(img, caption="Input")
        st.stop()

    # 2. SAM masks, one per detected box.
    # BUG FIX: DINO boxes come back as *normalized* (cx, cy, w, h); scaling
    # them by [W, H, W, H] directly treated them as xyxy and produced wrong
    # pixel boxes. Convert cxcywh -> xyxy first, then scale to pixels.
    cx, cy, bw, bh = boxes.unbind(-1)
    boxes_xyxy = torch.stack(
        [cx - bw / 2, cy - bh / 2, cx + bw / 2, cy + bh / 2], dim=-1
    ) * torch.tensor([W, H, W, H])
    sam.set_image(im_np)

    vis = im_np.copy()
    counter = {}
    for box, label in zip(boxes_xyxy, labels):
        m, _, _ = sam.predict(box=box.cpu().numpy(), multimask_output=False)
        if m[0].mean() < .005:  # discard tiny blobs (mask covers <0.5% of image)
            continue
        counter[label] = counter.get(label, 0) + 1
        # Draw mask outline and label for the preview image.
        cnt, _ = cv2.findContours(m[0].astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        color = tuple(int(x) for x in np.random.randint(0, 255, 3))
        cv2.drawContours(vis, cnt, -1, color, 2)
        x1, y1, x2, y2 = map(int, box)
        cv2.putText(vis, label, (x1, max(y1 - 5, 10)), cv2.FONT_HERSHEY_SIMPLEX, .5, color, 2)

    # 3. Output: per-class counts plus annotated preview.
    st.subheader("📊 Conteggio feature riconosciute")
    for k, v in counter.items():
        st.write(f"**{k}** : {v}")

    st.subheader("👁️🗨️ Preview")
    st.image(vis, caption="Mask outline + label")

st.caption("Per GPU: Settings → Hardware → T4 small (o superiore)")
|
| 79 |
-
|
| 80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pages/OR1_visual_classification.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import numpy as np
|
| 3 |
+
import cv2
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import pandas as pd
|
| 6 |
+
|
| 7 |
+
from utils import get_device, download_checkpoint
|
| 8 |
+
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
|
| 9 |
+
from transformers import CLIPProcessor, CLIPModel
|
| 10 |
+
|
| 11 |
+
# Page chrome: wide layout, Italian title for the OR1 recognition page.
st.set_page_config(page_title="OR1 – Riconoscimento", layout="wide")
st.title("🧩 OR1 – Riconoscimento visivo e classificazione funzionale")

# Resolve the compute device via the project helper and surface it to the user.
device = get_device()
st.sidebar.success(f"Device: **{device}**")

# SAM tuning knobs, exposed in the sidebar. These module-level globals are
# read by load_sam() below.
st.sidebar.header("Parametri SAM")
points_per_side = st.sidebar.slider("Points per side", 0, 128, 32)
pred_iou_thresh = st.sidebar.slider("Pred IoU Thresh", 0.0, 1.0, 0.8)
stability_score_thresh = st.sidebar.slider("Stability Score Thresh", 0.0, 1.0, 0.9)
|
| 22 |
+
|
| 23 |
+
@st.cache_resource
def load_sam():
    """Fetch the SAM ViT-H checkpoint (if missing) and build a mask generator.

    NOTE(review): this function is cached via st.cache_resource yet reads the
    sidebar slider globals — changing a slider will NOT rebuild the cached
    generator. Confirm whether that is intended; passing the values as
    arguments would make the cache key track them.
    """
    ckpt_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
    ckpt_name = "sam_vit_h_4b8939.pth"
    download_checkpoint(ckpt_url, ckpt_name)
    model = sam_model_registry["vit_h"](checkpoint=ckpt_name)
    model = model.to(device).eval()
    return SamAutomaticMaskGenerator(
        model,
        points_per_side=points_per_side,
        pred_iou_thresh=pred_iou_thresh,
        stability_score_thresh=stability_score_thresh,
    )
|
| 35 |
+
|
| 36 |
+
@st.cache_resource
def load_clip():
    """Load the LAION CLIP ViT-L/14 model plus its processor (cached)."""
    checkpoint = "laion/CLIP-ViT-L-14-laion2B-s32B-b82K"
    processor = CLIPProcessor.from_pretrained(checkpoint)
    model = CLIPModel.from_pretrained(checkpoint).to(device)
    return model, processor
|
| 42 |
+
|
| 43 |
+
# Instantiate the (cached) models once per session.
mask_generator = load_sam()
clip_model, clip_processor = load_clip()

# Step 1: image upload (multiple files allowed).
st.markdown("**1️⃣ Carica immagini** (JPG/PNG)")
uploaded = st.file_uploader("", type=["jpg", "jpeg", "png"], accept_multiple_files=True)

# Step 2: comma-separated candidate labels fed to CLIP.
st.markdown("**2️⃣ Inserisci etichette (es. lamiera, foro…)**")
default_labels = "lamiera, foro circolare, scanalatura rettangolare"
raw_labels = st.text_input("", default_labels)
labels = [part.strip() for part in raw_labels.split(",") if part.strip()]
|
| 53 |
+
|
| 54 |
+
# Analisi
|
| 55 |
+
# Step 3: run SAM segmentation + CLIP classification on every uploaded image.
if uploaded and st.button("🔎 Analizza immagini"):
    all_results = []
    for file in uploaded:
        img_pil = Image.open(file).convert("RGB")
        img_np = np.array(img_pil)
        st.subheader(f"📎 {file.name}")
        st.image(img_pil, caption="Immagine originale", use_column_width=True)

        with st.spinner("Segmentazione in corso…"):
            masks = mask_generator.generate(img_np)
        st.write(f"→ Segmenti trovati: {len(masks)}")

        masks_info = []
        for idx, m in enumerate(masks):
            segm = m["segmentation"]

            # BUG FIX: the old code handed the WHOLE image to CLIP for every
            # mask, so each segment received an identical label/confidence
            # (the computed mask was never used). Classify the segment's crop
            # instead, with off-mask pixels blanked to white.
            x, y, w, h = (int(v) for v in m["bbox"])  # SAM bbox is XYWH
            if w == 0 or h == 0:
                continue  # degenerate mask — nothing to classify
            crop = img_np[y:y + h, x:x + w].copy()
            crop[~segm[y:y + h, x:x + w].astype(bool)] = 255
            inputs = clip_processor(text=labels, images=Image.fromarray(crop),
                                    return_tensors="pt", padding=True)
            inputs = {k: v.to(device) for k, v in inputs.items()}
            out = clip_model(**inputs)
            probs = out.logits_per_image.softmax(dim=1)
            best_i = int(probs.argmax())

            masks_info.append({
                "Indice": idx,
                "Label": labels[best_i],
                "Confidence": round(float(probs[0, best_i]), 3),
                "Area(px)": int(m["area"]),
            })

        df = pd.DataFrame(masks_info)
        st.dataframe(df, use_container_width=True)

        # Draw mask contours + predicted labels over the original image.
        annotated = img_np.copy()
        for info in masks_info:
            segm_bin = (masks[info["Indice"]]["segmentation"] > 0.5).astype(np.uint8) * 255
            contours, _ = cv2.findContours(segm_bin, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            color = tuple(np.random.randint(0, 255, 3).tolist())
            for cnt in contours:
                cv2.drawContours(annotated, [cnt], -1, color, 2)
                bx, by, bw, bh = cv2.boundingRect(cnt)
                cv2.putText(annotated, info["Label"], (bx, by - 5),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
        st.image(annotated, caption="Overlay con etichette", use_column_width=True)

        # Tag rows with the source file and accumulate for the CSV export.
        for r in masks_info:
            r["File"] = file.name
        all_results += masks_info

    # Kept inside the button branch: all_results only exists after a click,
    # so a top-level check would raise NameError on an ordinary rerun.
    if all_results:
        df_all = pd.DataFrame(all_results)
        csv = df_all.to_csv(index=False).encode("utf-8")
        st.download_button("📥 Scarica risultati (CSV)", csv, "or1_results.csv", "text/csv")
|