Spaces:
Sleeping
Sleeping
import cv2
import numpy as np
import pandas as pd
import streamlit as st
import torch
from PIL import Image
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
from transformers import CLIPModel, CLIPProcessor

from utils import download_checkpoint, get_device
# Page layout and header.
st.set_page_config(page_title="OR1 – Riconoscimento", layout="wide")
st.title("🧩 OR1 – Riconoscimento visivo e classificazione funzionale")

# Compute device chosen by the project helper; shown in the sidebar.
device = get_device()
st.sidebar.success(f"Device: **{device}**")

# SAM parameters (read by load_sam() when building the mask generator).
st.sidebar.header("Parametri SAM")
# SAM prompts a points_per_side x points_per_side grid; a 0-point grid is
# invalid (SAM's grid builder divides by points_per_side), so start at 1.
points_per_side = st.sidebar.slider("Points per side", 1, 128, 32)
pred_iou_thresh = st.sidebar.slider("Pred IoU Thresh", 0.0, 1.0, 0.8)
stability_score_thresh = st.sidebar.slider("Stability Score Thresh", 0.0, 1.0, 0.9)
@st.cache_resource
def _load_sam_model():
    """Fetch the SAM ViT-H checkpoint and load the model on ``device`` in eval mode.

    Wrapped in ``st.cache_resource`` so the weights are loaded once per server
    process instead of on every Streamlit rerun (each widget change reruns
    the whole script).
    """
    url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
    fname = "sam_vit_h_4b8939.pth"
    download_checkpoint(url, fname)
    return sam_model_registry["vit_h"](checkpoint=fname).to(device).eval()


def load_sam():
    """Return a ``SamAutomaticMaskGenerator`` configured from the sidebar sliders.

    The heavy model comes from the cached ``_load_sam_model``; only the
    generator wrapper is rebuilt, so slider changes take effect on the next run.
    """
    return SamAutomaticMaskGenerator(
        _load_sam_model(),
        points_per_side=points_per_side,
        pred_iou_thresh=pred_iou_thresh,
        stability_score_thresh=stability_score_thresh,
    )
@st.cache_resource
def load_clip():
    """Load the LAION CLIP ViT-L/14 model (moved to ``device``) and its processor.

    Wrapped in ``st.cache_resource`` so the weights are loaded once per server
    process rather than on every Streamlit rerun.

    Returns:
        ``(clip_model, clip_processor)``.
    """
    model_name = "laion/CLIP-ViT-L-14-laion2B-s32B-b82K"
    # Explicit eval(): inference only, no dropout/training behaviour.
    clip_model = CLIPModel.from_pretrained(model_name).to(device).eval()
    clip_processor = CLIPProcessor.from_pretrained(model_name)
    return clip_model, clip_processor
# Instantiate the SAM mask generator first, then the CLIP model/processor
# pair used by the analysis section below.
mask_generator = load_sam()
clip_model, clip_processor = load_clip()
# Image upload
st.markdown("**1️⃣ Carica immagini** (JPG/PNG)")
# A non-empty label is kept for accessibility; it is hidden because the
# markdown line above already serves as the visible caption.
uploaded = st.file_uploader(
    "Immagini",
    type=["jpg", "jpeg", "png"],
    accept_multiple_files=True,
    label_visibility="collapsed",
)

# Labels
st.markdown("**2️⃣ Inserisci etichette (es. lamiera, foro…)**")
default_labels = "lamiera, foro circolare, scanalatura rettangolare"
raw_labels = st.text_input("Etichette", default_labels, label_visibility="collapsed")
# Comma-separated list: trim each entry and drop blanks. Duplicates are
# removed (first occurrence kept) so a repeated class does not split the
# CLIP softmax probability between identical labels.
labels = list(dict.fromkeys(lab.strip() for lab in raw_labels.split(",") if lab.strip()))
# Analysis

# Value painted over pixels that lie inside a segment's bounding box but
# outside its mask, so CLIP sees the segment rather than its surroundings.
# This is a design choice; adjust if another fill classifies better.
_CROP_FILL = 255
# Maximum number of segment crops sent to CLIP in one forward pass.
_CLIP_BATCH = 32


def _mask_crop(img_np, segm):
    """Return the pixels of *img_np* selected by mask *segm* as a PIL image.

    The crop is the tight bounding box of the mask's non-zero pixels; pixels in
    that box but outside the mask are set to ``_CROP_FILL``. Returns ``None``
    for an empty mask. The box is derived from the mask itself so even 1-pixel
    wide segments yield a non-empty crop.
    """
    ys, xs = np.nonzero(segm)
    if ys.size == 0:
        return None
    y0, y1 = ys.min(), ys.max() + 1
    x0, x1 = xs.min(), xs.max() + 1
    crop = img_np[y0:y1, x0:x1].copy()
    inside = segm[y0:y1, x0:x1].astype(bool)
    crop[~inside] = _CROP_FILL
    return Image.fromarray(crop)


def _classify_crops(crops, labels):
    """Zero-shot classify each PIL image in *crops* against *labels* with CLIP.

    Returns a list of ``(best_label, probability)`` pairs in the order of
    *crops*; the probability is the softmax over *labels* for that crop.
    """
    results = []
    for start in range(0, len(crops), _CLIP_BATCH):
        batch = crops[start:start + _CLIP_BATCH]
        inputs = clip_processor(text=labels, images=batch, return_tensors="pt", padding=True)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            probs = clip_model(**inputs).logits_per_image.softmax(dim=1)
        best_p, best_i = probs.max(dim=1)
        for p, i in zip(best_p.tolist(), best_i.tolist()):
            results.append((labels[i], p))
    return results


def _draw_overlay(img_np, masks, masks_info, seed=0):
    """Return a copy of *img_np* with each classified segment outlined and labelled.

    *masks_info* rows reference *masks* through their ``"Indice"`` key. A fixed
    *seed* keeps the per-segment colours identical across Streamlit reruns.
    """
    annotated = img_np.copy()
    rng = np.random.default_rng(seed)
    for info in masks_info:
        segm_bin = (masks[info["Indice"]]["segmentation"] > 0.5).astype(np.uint8) * 255
        contours, _ = cv2.findContours(segm_bin, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            continue
        color = tuple(int(c) for c in rng.integers(0, 256, 3))
        cv2.drawContours(annotated, contours, -1, color, 2)
        # Write the label once, above the largest outer contour, clamped so
        # it stays inside the image for segments touching the top edge.
        # NOTE: OpenCV's Hershey fonts cover ASCII only; accented characters
        # in labels may not render as typed.
        x, y, _, _ = cv2.boundingRect(max(contours, key=cv2.contourArea))
        cv2.putText(annotated, info["Label"], (x, max(y - 5, 12)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
    return annotated


if uploaded and st.button("🔎 Analizza immagini"):
    # CLIP cannot score against an empty class list.
    if not labels:
        st.warning("Inserisci almeno un'etichetta.")
        st.stop()

    all_results = []
    for file in uploaded:
        try:
            img_pil = Image.open(file).convert("RGB")
        except OSError:  # includes PIL.UnidentifiedImageError and truncated files
            st.error(f"Impossibile leggere {file.name}: immagine non valida.")
            continue
        img_np = np.array(img_pil)
        st.subheader(f"📎 {file.name}")
        st.image(img_pil, caption="Immagine originale", use_column_width=True)

        with st.spinner("Segmentazione in corso…"):
            masks = mask_generator.generate(img_np)
        st.write(f"→ Segmenti trovati: {len(masks)}")

        # Classify each segment on its own (masked, cropped) pixels rather
        # than on the whole image; empty masks are skipped.
        indexed_crops = [
            (idx, crop)
            for idx, m in enumerate(masks)
            if (crop := _mask_crop(img_np, m["segmentation"])) is not None
        ]
        with st.spinner("Classificazione in corso…"):
            predictions = _classify_crops([crop for _, crop in indexed_crops], labels)

        masks_info = [
            {
                "Indice": idx,
                "Label": label,
                "Confidence": round(conf, 3),
                "Area(px)": int(masks[idx]["area"]),
            }
            for (idx, _), (label, conf) in zip(indexed_crops, predictions)
        ]
        df = pd.DataFrame(masks_info)
        st.dataframe(df, use_container_width=True)

        annotated = _draw_overlay(img_np, masks, masks_info)
        st.image(annotated, caption="Overlay con etichette", use_column_width=True)

        for r in masks_info:
            r["File"] = file.name
        all_results += masks_info

    # NOTE(review): clicking a Streamlit download button reruns the script,
    # which clears the analysis output above; persist results in
    # st.session_state if they should survive the download.
    if all_results:
        df_all = pd.DataFrame(all_results)
        csv = df_all.to_csv(index=False).encode("utf-8")
        st.download_button("📥 Scarica risultati (CSV)", csv, "or1_results.csv", "text/csv")