# Upcycling_appAI / OR1_visual_classification.py
# (originally published on Hugging Face by user "valegro", commit 017f96b)
import cv2
import numpy as np
import pandas as pd
import streamlit as st
import torch
from PIL import Image
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from transformers import CLIPProcessor, CLIPModel

from utils import get_device, download_checkpoint
# Page chrome and compute device.
st.set_page_config(page_title="OR1 – Riconoscimento", layout="wide")
st.title("🧩 OR1 – Riconoscimento visivo e classificazione funzionale")

device = get_device()
st.sidebar.success(f"Device: **{device}**")

# SAM hyper-parameters exposed in the sidebar.
st.sidebar.header("Parametri SAM")
# BUGFIX: minimum raised from 0 to 4 — SamAutomaticMaskGenerator samples a
# points_per_side × points_per_side grid, and a zero-sized grid makes mask
# generation fail outright.
points_per_side = st.sidebar.slider("Points per side", 4, 128, 32)
pred_iou_thresh = st.sidebar.slider("Pred IoU Thresh", 0.0, 1.0, 0.8)
stability_score_thresh = st.sidebar.slider("Stability Score Thresh", 0.0, 1.0, 0.9)
@st.cache_resource
def _load_sam_model():
    """Download (if needed) and cache the heavy ViT-H SAM backbone.

    Cached once per session: the checkpoint is ~2.4 GB and loading it is by
    far the most expensive step, so it must not depend on the sliders.
    """
    url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
    fname = "sam_vit_h_4b8939.pth"
    download_checkpoint(url, fname)
    return sam_model_registry["vit_h"](checkpoint=fname).to(device).eval()


def load_sam(points_per_side, pred_iou_thresh, stability_score_thresh):
    """Build a SamAutomaticMaskGenerator for the current sidebar settings.

    BUGFIX: the original decorated this with @st.cache_resource while reading
    the sliders through its closure, so the cached generator was built once
    and slider changes never took effect.  The generator is cheap to
    construct, so it is now rebuilt on every rerun from explicit arguments;
    only the underlying SAM model (above) is cached.
    """
    return SamAutomaticMaskGenerator(
        _load_sam_model(),
        points_per_side=points_per_side,
        pred_iou_thresh=pred_iou_thresh,
        stability_score_thresh=stability_score_thresh,
    )


@st.cache_resource
def load_clip():
    """Load and cache the CLIP model + processor used for zero-shot labels."""
    model_name = "laion/CLIP-ViT-L-14-laion2B-s32B-b82K"
    clip_model = CLIPModel.from_pretrained(model_name).to(device)
    clip_processor = CLIPProcessor.from_pretrained(model_name)
    return clip_model, clip_processor


mask_generator = load_sam(points_per_side, pred_iou_thresh, stability_score_thresh)
clip_model, clip_processor = load_clip()
# --- Input widgets ---------------------------------------------------------
# Image upload (multiple files allowed).
st.markdown("**1️⃣ Carica immagini** (JPG/PNG)")
uploaded = st.file_uploader("", type=["jpg", "jpeg", "png"], accept_multiple_files=True)

# Free-text, comma-separated candidate labels for zero-shot classification.
st.markdown("**2️⃣ Inserisci etichette (es. lamiera, foro…)**")
default_labels = "lamiera, foro circolare, scanalatura rettangolare"
raw_labels = st.text_input("", default_labels)
labels = [token.strip() for token in raw_labels.split(",") if token.strip()]
# --- Analysis --------------------------------------------------------------
if uploaded and st.button("🔎 Analizza immagini"):
    all_results = []
    for file in uploaded:
        img_pil = Image.open(file).convert("RGB")
        img_np = np.array(img_pil)
        st.subheader(f"📎 {file.name}")
        st.image(img_pil, caption="Immagine originale", use_column_width=True)

        # Automatic segmentation of the whole image with SAM.
        with st.spinner("Segmentazione in corso…"):
            masks = mask_generator.generate(img_np)
        st.write(f"→ Segmenti trovati: {len(masks)}")

        masks_info = []
        for idx, m in enumerate(masks):
            # BUGFIX: the original passed the *full* image to CLIP for every
            # mask, so all segments received the identical label/confidence.
            # Classify the crop of this segment's bounding box instead.
            # SAM returns "bbox" in XYWH format.
            x, y, w, h = (int(v) for v in m["bbox"])
            if w <= 0 or h <= 0:
                continue  # degenerate box: nothing meaningful to classify
            crop = Image.fromarray(img_np[y:y + h, x:x + w])
            inputs = clip_processor(text=labels, images=crop,
                                    return_tensors="pt", padding=True)
            inputs = {k: v.to(device) for k, v in inputs.items()}
            with torch.no_grad():  # inference only — skip autograd bookkeeping
                out = clip_model(**inputs)
            probs = out.logits_per_image.softmax(dim=1)
            best_i = int(probs.argmax())
            masks_info.append({
                "Indice": idx,
                "Label": labels[best_i],
                "Confidence": round(float(probs[0, best_i]), 3),
                "Area(px)": int(m["area"]),
            })

        df = pd.DataFrame(masks_info)
        st.dataframe(df, use_container_width=True)

        # Draw each classified segment's contour and label on a copy.
        annotated = img_np.copy()
        for info in masks_info:
            segm_bin = (masks[info["Indice"]]["segmentation"] > 0.5).astype(np.uint8) * 255
            contours, _ = cv2.findContours(segm_bin, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            color = tuple(np.random.randint(0, 255, 3).tolist())
            for cnt in contours:
                cv2.drawContours(annotated, [cnt], -1, color, 2)
                bx, by, _, _ = cv2.boundingRect(cnt)
                cv2.putText(annotated, info["Label"], (bx, by - 5),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
        st.image(annotated, caption="Overlay con etichette", use_column_width=True)

        # Tag rows with the source file before accumulating across images.
        for r in masks_info:
            r["File"] = file.name
        all_results += masks_info

    if all_results:
        df_all = pd.DataFrame(all_results)
        csv = df_all.to_csv(index=False).encode("utf-8")
        st.download_button("📥 Scarica risultati (CSV)", csv, "or1_results.csv", "text/csv")