valegro committed on
Commit
ebf76ce
·
verified ·
1 Parent(s): 7902b7a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -47
app.py CHANGED
@@ -1,49 +1,96 @@
1
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw

from huggingface_hub import hf_hub_download

# Download SAM and GroundingDINO checkpoints directly from the HF Hub.
# NOTE(review): these repo/filename pairs look like placeholders —
# "facebook/sam-vit-base" hosts transformers-format weights, not a
# "sam_vit_b.pth" file; verify each repo_id/filename against the Hub.
SAM_CHECKPOINT = hf_hub_download("facebook/sam-vit-base", "sam_vit_b.pth")
GDINO_CONFIG = hf_hub_download("IDEA-Research/GroundingDINO", "GroundingDINO_SwinT_OGC.py")
GDINO_CHECKPT = hf_hub_download("IDEA-Research/GroundingDINO", "groundingdino_swint_ogc.pth")

# — pseudocode imports — (NOTE(review): the real segment_anything package
# exposes sam_model_registry / SamPredictor, not sam_model.load_from_checkpoint)
from segment_anything import sam_model
from groundingdino.util.inference import load_model, predict

# Instantiate both models once at module import time.
sam = sam_model.load_from_checkpoint(SAM_CHECKPOINT)
gdino = load_model(GDINO_CONFIG, GDINO_CHECKPT)
17
def recognize(img, prompt, conf):
    """Segment *img* with SAM, label every mask with GroundingDINO, and
    overlay a red box + score caption for each detection at or above *conf*.

    Returns a tuple of (RGBA overlay as a numpy array, list of detection
    dicts with "label", "score", "area" and "bbox" keys).
    """
    overlay = Image.fromarray(img).convert("RGBA")
    painter = ImageDraw.Draw(overlay, "RGBA")
    red = (255, 0, 0, 180)

    detections = []
    for mask in sam.segment(img):  # 1. zero-shot segmentation
        label, score = predict(gdino, img, mask, prompt)  # 2. zero-shot classification
        if score < conf:
            continue
        rows, cols = np.where(mask)
        box = (cols.min(), rows.min(), cols.max(), rows.max())
        painter.rectangle(box, outline=red, width=3)
        painter.text((box[0], box[1] - 10), f"{label} {score:.2f}", fill=red)
        detections.append(
            {"label": label, "score": score, "area": int(mask.sum()), "bbox": box}
        )
    return np.array(overlay), detections
33
# Gradio UI: image + text prompt + confidence slider in, overlay + JSON out.
_inputs = [
    gr.Image(type="numpy", label="Upload Image"),
    gr.Textbox(label="Prompt (comma‑separated)"),
    gr.Slider(0, 1, 0.25, label="Confidence Threshold"),
]
_outputs = [
    gr.Image(label="Overlay"),
    gr.JSON(label="Detections"),
]
app = gr.Interface(
    fn=recognize,
    inputs=_inputs,
    outputs=_outputs,
    title="Zero‑Shot Component Recognition",
    description="Segmenta e classifica componenti meccanici da foto, senza training specifico.",
)

if __name__ == "__main__":
    app.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import streamlit as st
import numpy as np
import torch
from PIL import Image
import cv2
import matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download
from segment_anything import SamPredictor, sam_model_registry
from groundingdino.util.inference import load_model, predict, annotate

# App title
st.title("🔍 Riconoscimento Zero-Shot con GroundingDINO + SAM")
13
+
14
+ # Configurazione dei modelli da Hugging Face Hub
15
@st.cache_resource
def load_sam():
    """Download the SAM ViT-B checkpoint from the HF Hub and wrap it in a
    SamPredictor on the best available device.

    Cached by Streamlit (@st.cache_resource) so the model is built once
    per server process, not once per rerun.
    """
    # BUG FIX: the original repo_id "SegmentAnything/sam_vit_b" does not exist
    # on the Hub, so the download always failed. "ybelkada/segment-anything"
    # mirrors the official Meta checkpoints under a checkpoints/ folder.
    # NOTE(review): confirm the mirror is still available before deploying.
    checkpoint = hf_hub_download(
        repo_id="ybelkada/segment-anything",
        filename="checkpoints/sam_vit_b_01ec64.pth",
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = sam_model_registry["vit_b"](checkpoint=checkpoint)
    return SamPredictor(model.to(device))
23
+
24
@st.cache_resource
def load_grounding_dino():
    """Download the GroundingDINO SwinT-OGC config and weights and build the
    model via groundingdino's original-format loader.

    Cached by Streamlit (@st.cache_resource) so it runs once per process.
    """
    # BUG FIX: "IDEA-Research/grounding-dino-tiny" hosts transformers-format
    # weights (config.json + safetensors) and contains neither a
    # "GroundingDINO_SwinT_OGC.py" config nor a "groundingdino_tiny.pth"
    # checkpoint, so both downloads 404'd. The original-format files used by
    # groundingdino.util.inference.load_model live in ShilongLiu/GroundingDINO.
    config_path = hf_hub_download(
        repo_id="ShilongLiu/GroundingDINO",
        filename="GroundingDINO_SwinT_OGC.cfg.py",
    )
    checkpoint_path = hf_hub_download(
        repo_id="ShilongLiu/GroundingDINO",
        filename="groundingdino_swint_ogc.pth",
    )
    return load_model(config_path, checkpoint_path)
36
+
37
sam = load_sam()
grounding_dino = load_grounding_dino()

# --- UI: image upload + class prompt ---------------------------------------
uploaded_image = st.file_uploader("📷 Carica un'immagine", type=['jpg', 'jpeg', 'png'])

prompt = st.text_input("📝 Inserisci le classi da riconoscere (separate da virgola)",
                       value="lamiera, foro circolare, vite, bullone, scanalatura")

if uploaded_image is not None:
    image = Image.open(uploaded_image).convert("RGB")
    img_array = np.array(image)  # HWC, uint8, RGB

    st.image(image, caption="Immagine caricata", use_column_width=True)

    if st.button("▶️ Avvia riconoscimento"):
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # BUG FIX: groundingdino.util.inference.predict expects a normalized
        # CHW float tensor (what its load_image helper returns), not a raw
        # HWC uint8 numpy array. Apply the standard ImageNet normalization.
        mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        img_tensor = torch.from_numpy(img_array).permute(2, 0, 1).float() / 255.0
        img_tensor = (img_tensor - mean) / std

        # GroundingDINO open-vocabulary detection
        boxes, logits, phrases = predict(
            model=grounding_dino,
            image=img_tensor,
            caption=prompt,
            box_threshold=0.3,
            text_threshold=0.25,
            device=device,
        )

        # BUG FIX: annotate() draws with OpenCV and returns a BGR image;
        # convert to RGB before handing it to st.image.
        annotated_frame = annotate(image_source=img_array, boxes=boxes,
                                   logits=logits, phrases=phrases)
        annotated_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)

        st.subheader("Risultato GroundingDINO")
        st.image(annotated_frame, caption="Annotazione GroundingDINO")

        # SAM segmentation, prompted with the detected boxes.
        # Robustness: skip entirely when nothing was detected (predict_torch
        # would fail on an empty box tensor).
        if len(boxes) > 0:
            sam.set_image(img_array)
            H, W, _ = img_array.shape

            # BUG FIX: GroundingDINO boxes are normalized (cx, cy, w, h).
            # Scaling by [W, H, W, H] alone leaves them in cxcywh; SAM needs
            # absolute-pixel corner format (x1, y1, x2, y2).
            scaled = boxes * torch.tensor([W, H, W, H])
            cx, cy, bw, bh = scaled.unbind(-1)
            boxes_xyxy = torch.stack(
                [cx - bw / 2, cy - bh / 2, cx + bw / 2, cy + bh / 2], dim=-1
            )

            # BUG FIX: predict_torch requires boxes pre-mapped to the model's
            # input resolution via the predictor's ResizeLongestSide transform.
            transformed_boxes = sam.transform.apply_boxes_torch(
                boxes_xyxy.to(sam.device), img_array.shape[:2]
            )
            masks, scores, _ = sam.predict_torch(
                point_coords=None,
                point_labels=None,
                boxes=transformed_boxes,
                multimask_output=False,
            )

            # Draw each mask outline over the original image.
            st.subheader("Risultato Segment Anything (SAM)")
            fig = plt.figure(figsize=(10, 10))
            plt.imshow(img_array)
            for mask in masks:
                mask_np = mask[0].cpu().numpy()
                plt.contour(mask_np, colors='red', linewidths=1.5)
            plt.axis('off')
            st.pyplot(fig)
            plt.close(fig)

        # Detection summary table (one row per phrase, rounded confidence).
        st.subheader("🔖 Tabella Risultati")
        result_data = [{"Classe": phrase, "Confidenza": round(logit.item(), 2)}
                       for phrase, logit in zip(phrases, logits)]
        st.table(result_data)