eduardo4547 commited on
Commit
cbbef9c
·
verified ·
1 Parent(s): cf5a81c

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -168
app.py CHANGED
@@ -1,239 +1,177 @@
1
  import os
2
- import re
3
- from pathlib import Path
4
-
5
  import gradio as gr
6
  import numpy as np
7
  import torch
 
8
  from huggingface_hub import hf_hub_download
9
  from PIL import Image
10
- from transformers import CLIPModel, CLIPProcessor
11
  import spaces # <-- Importante para Hugging Face ZeroGPU
12
 
13
- # --- IMPORTACIONES DE SAM 2.1 ---
 
14
  from sam2.build_sam import build_sam2
15
- from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
16
 
17
- # --- CONFIGURACIÓN DE SAM 2.1 ---
 
18
  SAM2_REPO = "facebook/sam2.1-hiera-base-plus"
19
  CHECKPOINT_FILENAME = "sam2.1_hiera_base_plus.pt"
20
  SAM2_CONFIG = "configs/sam2.1/sam2.1_hiera_b+.yaml"
21
 
22
- CLIP_MODEL_NAME = "openai/clip-vit-base-patch32"
 
23
 
24
- # Volvemos a CUDA, ya que ahora cargaremos los modelos dentro de la función autorizada
25
  DEVICE = "cuda"
26
- CLIP_THRESHOLD = 0.26
27
 
28
- # Variables globales para los modelos (se inicializan vacías)
29
- sam2_model = None
30
- mask_generator = None
31
  clip_model = None
32
  clip_processor = None
33
 
34
  COLOR_PALETTE = [
35
- (255, 0, 0, 150), (0, 255, 0, 150), (0, 0, 255, 150), (255, 255, 0, 150),
36
- (0, 255, 255, 150), (255, 0, 255, 150), (255, 165, 0, 150), (128, 0, 128, 150),
 
 
 
 
37
  ]
38
 
39
- def download_checkpoint() -> str:
40
- """Descarga el modelo SAM 2.1 desde Hugging Face."""
41
  cache_dir = Path("./models")
42
  cache_dir.mkdir(parents=True, exist_ok=True)
43
  local_path = cache_dir / CHECKPOINT_FILENAME
44
-
45
  if not local_path.exists():
46
- print(f"Descargando {CHECKPOINT_FILENAME} desde Hugging Face...")
47
- local_path = Path(
48
- hf_hub_download(
49
- repo_id=SAM2_REPO,
50
- filename=CHECKPOINT_FILENAME,
51
- cache_dir=str(cache_dir),
52
- )
53
- )
54
- print("¡Descarga completada!")
55
  return str(local_path)
56
 
57
- def create_mask_overlay(image: Image.Image, masks: list[dict]) -> Image.Image:
58
- image = image.convert("RGBA")
59
- overlay_image = image.copy()
60
-
61
- # Ordenar por área para que las máscaras más pequeñas se dibujen encima
62
- sorted_masks = sorted(masks, key=(lambda x: x["area"]), reverse=True)
63
-
64
- for i, mask_data in enumerate(sorted_masks):
65
  color = COLOR_PALETTE[i % len(COLOR_PALETTE)]
66
- mask_bool = mask_data["segmentation"]
67
-
68
- # Convertir la matriz booleana a una imagen de escala de grises (L)
69
- mask_image = Image.fromarray(mask_bool.astype(np.uint8) * 255, mode="L")
70
-
71
- # Crear una capa del mismo tamaño con el color correspondiente
72
- color_overlay = Image.new("RGBA", image.size, color)
73
-
74
- # Pegar el color transparente SOBRE la imagen original, usando la máscara
75
  overlay_image.paste(color_overlay, (0, 0), mask_image)
76
 
77
  return overlay_image
78
 
79
- def mask_to_bbox(mask: np.ndarray):
80
- ys, xs = np.where(mask.astype(np.uint8))
81
- if ys.size == 0 or xs.size == 0:
82
- return None
83
- return int(xs.min()), int(ys.min()), int(xs.max()) + 1, int(ys.max()) + 1
84
-
85
- def crop_masked_region(image: Image.Image, mask: np.ndarray) -> Image.Image | None:
86
- bbox = mask_to_bbox(mask)
87
- if bbox is None:
88
- return None
89
- mask_img = Image.fromarray((mask.astype(np.uint8) * 255).astype(np.uint8), mode="L")
90
- background = Image.new("RGB", image.size, (127, 127, 127))
91
- masked = Image.composite(image, background, mask_img)
92
- return masked.crop(bbox)
93
-
94
- def normalize_features(features: torch.Tensor | object) -> torch.Tensor:
95
- if hasattr(features, "pooler_output"):
96
- features = features.pooler_output
97
- elif hasattr(features, "last_hidden_state"):
98
- features = features.last_hidden_state[:, 0, :]
99
- if not isinstance(features, torch.Tensor):
100
- raise RuntimeError("No se pudieron obtener características de CLIP.")
101
- return features / features.norm(dim=-1, keepdim=True)
102
-
103
- def compute_clip_features(images: list[Image.Image]):
104
- inputs = clip_processor(images=images, return_tensors="pt", padding=True).to(DEVICE)
105
- with torch.no_grad():
106
- features = clip_model.get_image_features(**inputs)
107
- return normalize_features(features)
108
-
109
- def select_masks_by_text(image: Image.Image, masks: list[dict], prompt: str) -> tuple[list[dict], list[tuple[str, float | None]]]:
110
- terms = [t.strip() for t in re.split(r"[,\n]+", prompt) if t.strip()]
111
- if len(terms) == 0:
112
- return [], []
113
-
114
- crops = []
115
- valid_masks = []
116
- for mask in masks:
117
- crop = crop_masked_region(image, mask["segmentation"])
118
- if crop is not None:
119
- valid_masks.append(mask)
120
- crops.append(crop)
121
-
122
- if len(crops) == 0:
123
- return [], [(term, None) for term in terms]
124
-
125
- image_features = compute_clip_features(crops)
126
- text_prompts = [f"A photo of a {term}." for term in terms]
127
- text_inputs = clip_processor(text=text_prompts, return_tensors="pt", padding=True).to(DEVICE)
128
-
129
- with torch.no_grad():
130
- text_features = clip_model.get_text_features(**text_inputs)
131
- text_features = normalize_features(text_features)
132
-
133
- similarities = (image_features @ text_features.T).cpu()
134
- selected_indices = set()
135
- hits = []
136
-
137
- for term_idx, term in enumerate(terms):
138
- scores = similarities[:, term_idx]
139
- valid_idxs = torch.where(scores >= CLIP_THRESHOLD)[0].tolist()
140
-
141
- if valid_idxs:
142
- selected_indices.update(valid_idxs)
143
- best_score = float(torch.max(scores[valid_idxs]).item())
144
- hits.append((term, best_score))
145
- else:
146
- hits.append((term, None))
147
-
148
- selected = [valid_masks[i] for i in sorted(selected_indices)]
149
- return selected, hits
150
-
151
  @spaces.GPU
152
  @torch.no_grad()
153
- def segmentar_imagen(imagen: Image.Image, texto: str):
154
- # Declaramos que vamos a modificar las variables globales
155
- global sam2_model, mask_generator, clip_model, clip_processor
156
 
157
- if imagen is None:
158
- return None, "Subí una imagen para segmentar."
159
-
160
- # --- CARGA DIFERIDA DE MODELOS ---
161
- # Esto solo se ejecuta la PRIMERA vez que el usuario hace clic.
162
- # Como ya estamos dentro de @spaces.GPU, el acceso a "cuda" es legal.
163
- if sam2_model is None:
164
- print("Inicializando modelos en GPU por primera vez...")
165
-
166
- # Activar precisiones mixtas para acelerar
167
  torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
168
  if torch.cuda.get_device_properties(0).major >= 8:
169
  torch.backends.cuda.matmul.allow_tf32 = True
170
  torch.backends.cudnn.allow_tf32 = True
171
 
 
 
172
  sam2_model = build_sam2(SAM2_CONFIG, checkpoint_path, device=DEVICE)
173
- mask_generator = SAM2AutomaticMaskGenerator(sam2_model)
174
 
175
- clip_model = CLIPModel.from_pretrained(CLIP_MODEL_NAME).to(DEVICE)
176
- clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)
177
- print("¡Modelos cargados exitosamente!")
 
 
 
 
 
 
178
 
179
- # --- INICIO DEL PROCESAMIENTO ---
180
  imagen = imagen.convert("RGB")
181
  imagen_np = np.array(imagen)
182
-
183
- masks = mask_generator.generate(imagen_np)
184
 
185
- if len(masks) == 0:
186
- return None, "No se generaron máscaras para esta imagen."
187
-
188
- texto = texto.strip()
189
- if texto == "":
190
- overlay = create_mask_overlay(imagen, masks)
191
- return overlay, f"Generadas {len(masks)} máscaras con SAM 2.1 (Base+)."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
 
193
- selected_masks, hits = select_masks_by_text(imagen, masks, texto)
194
- if len(selected_masks) == 0:
195
- terms = [t.strip() for t in re.split(r"[,\n]+", texto) if t.strip()]
196
- return None, f"No se encontró un objeto que coincida con: {', '.join(terms)}."
197
 
198
- found_terms = [term for term, score in hits if score is not None]
199
- missing_terms = [term for term, score in hits if score is None]
200
- overlay = create_mask_overlay(imagen, selected_masks)
201
 
202
- message = f"Encontradas {len(selected_masks)} máscara(s) para: {', '.join(found_terms)}."
203
- if missing_terms:
204
- message += f" No se encontró: {', '.join(missing_terms)}."
205
- return overlay, message
206
 
207
  def crear_app():
208
- with gr.Blocks(title="Gradio + SAM 2.1 Demo") as demo:
209
- gr.Markdown("# 🎯 Segmentación automática con SAM 2.1 (Base Plus) y CLIP")
210
  gr.Markdown(
211
- "Subí una imagen y escribe una palabra para encontrar y segmentar el objeto deseado. Si dejas el texto vacío, se mostrarán todas las máscaras generadas.\n\n*Nota: La primera imagen tardará unos segundos más mientras se inicializan los modelos en la GPU.*"
 
212
  )
213
 
214
- with gr.Row(equal_height=True):
215
  with gr.Column(scale=1):
216
- imagen_entrada = gr.Image(type="pil", label="Subí tu imagen")
217
- texto_objeto = gr.Textbox(label="Buscar objeto", placeholder="Ej. perro, coche, persona")
218
- boton = gr.Button("Segmentar")
 
 
 
 
 
219
  with gr.Column(scale=1):
220
- imagen_salida = gr.Image(label="Resultado segmentado")
221
  estado = gr.Textbox(label="Estado", interactive=False)
222
 
223
  boton.click(
224
- fn=segmentar_imagen,
225
- inputs=[imagen_entrada, texto_objeto],
226
  outputs=[imagen_salida, estado],
227
  )
228
 
229
  return demo
230
 
231
  # --- INICIALIZACIÓN GLOBAL ---
232
- # Solo descargamos el peso del modelo al arrancar (esto no usa GPU)
233
- print("Verificando/Descargando archivo del modelo de SAM 2...")
234
- checkpoint_path = download_checkpoint()
235
 
236
- # Iniciar App (los modelos se cargarán al hacer clic en Segmentar)
237
  demo = crear_app()
238
 
239
  if __name__ == "__main__":
 
1
  import os
 
 
 
2
  import gradio as gr
3
  import numpy as np
4
  import torch
5
+ from pathlib import Path
6
  from huggingface_hub import hf_hub_download
7
  from PIL import Image
 
8
  import spaces # <-- Importante para Hugging Face ZeroGPU
9
 
10
+ # --- IMPORTACIONES DE MODELOS ---
11
+ from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
12
  from sam2.build_sam import build_sam2
13
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
14
 
15
+ # --- CONFIGURACIÓN DE MODELOS ---
16
+ # SAM 2.1
17
  SAM2_REPO = "facebook/sam2.1-hiera-base-plus"
18
  CHECKPOINT_FILENAME = "sam2.1_hiera_base_plus.pt"
19
  SAM2_CONFIG = "configs/sam2.1/sam2.1_hiera_b+.yaml"
20
 
21
+ # GroundingDINO
22
+ GDINO_ID = "IDEA-Research/grounding-dino-base"
23
 
 
24
  DEVICE = "cuda"
 
25
 
26
+ # Variables globales para Lazy Loading (ZeroGPU)
27
+ sam2_predictor = None
28
+ gdino_model = None
29
  clip_model = None
30
  clip_processor = None
31
 
32
  COLOR_PALETTE = [
33
+ (0, 255, 255, 150), # Cian (queda muy bien para resaltar)
34
+ (255, 0, 255, 150), # Magenta
35
+ (255, 255, 0, 150), # Amarillo
36
+ (0, 255, 0, 150), # Verde
37
+ (255, 0, 0, 150), # Rojo
38
+ (0, 0, 255, 150), # Azul
39
  ]
40
 
41
+ def download_sam_checkpoint() -> str:
 
42
  cache_dir = Path("./models")
43
  cache_dir.mkdir(parents=True, exist_ok=True)
44
  local_path = cache_dir / CHECKPOINT_FILENAME
 
45
  if not local_path.exists():
46
+ print(f"Descargando {CHECKPOINT_FILENAME}...")
47
+ local_path = Path(hf_hub_download(repo_id=SAM2_REPO, filename=CHECKPOINT_FILENAME, cache_dir=str(cache_dir)))
 
 
 
 
 
 
 
48
  return str(local_path)
49
 
50
+ def create_mask_overlay(image: Image.Image, masks_np: np.ndarray) -> Image.Image:
51
+ """Superpone las máscaras booleanas (N, H, W) sobre la imagen."""
52
+ overlay_image = image.convert("RGBA").copy()
53
+
54
+ for i, mask_bool in enumerate(masks_np):
 
 
 
55
  color = COLOR_PALETTE[i % len(COLOR_PALETTE)]
56
+ mask_image = Image.fromarray((mask_bool * 255).astype(np.uint8), mode="L")
57
+ color_overlay = Image.new("RGBA", overlay_image.size, color)
 
 
 
 
 
 
 
58
  overlay_image.paste(color_overlay, (0, 0), mask_image)
59
 
60
  return overlay_image
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  @spaces.GPU
63
  @torch.no_grad()
64
+ def segmentar_con_dino_y_sam(imagen: Image.Image, texto: str, box_threshold: float):
65
+ global sam2_predictor, gdino_model, gdino_processor
 
66
 
67
+ if imagen is None or not texto.strip():
68
+ return None, "Sube una imagen y escribe qué quieres buscar."
69
+
70
+ # 1. LAZY LOADING: Inicializar modelos en la GPU la primera vez
71
+ if sam2_predictor is None:
72
+ print("Inicializando GroundingDINO y SAM 2.1 en GPU...")
 
 
 
 
73
  torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
74
  if torch.cuda.get_device_properties(0).major >= 8:
75
  torch.backends.cuda.matmul.allow_tf32 = True
76
  torch.backends.cudnn.allow_tf32 = True
77
 
78
+ # Cargar SAM 2.1 en modo Predictor (para cajas), no AutomaticMaskGenerator
79
+ checkpoint_path = download_sam_checkpoint()
80
  sam2_model = build_sam2(SAM2_CONFIG, checkpoint_path, device=DEVICE)
81
+ sam2_predictor = SAM2ImagePredictor(sam2_model)
82
 
83
+ # Cargar GroundingDINO
84
+ gdino_processor = AutoProcessor.from_pretrained(GDINO_ID)
85
+ gdino_model = AutoModelForZeroShotObjectDetection.from_pretrained(GDINO_ID).to(DEVICE)
86
+ print("¡Modelos listos!")
87
+
88
+ # Asegurarnos de que el texto termine en punto (GroundingDINO funciona mejor así)
89
+ texto = texto.strip()
90
+ if not texto.endswith("."):
91
+ texto += "."
92
 
 
93
  imagen = imagen.convert("RGB")
94
  imagen_np = np.array(imagen)
 
 
95
 
96
+ # 2. GROUNDING DINO: Encontrar las cajas delimitadoras
97
+ inputs = gdino_processor(images=imagen, text=texto, return_tensors="pt").to(DEVICE)
98
+ outputs = gdino_model(**inputs)
99
+
100
+ # Extraer las cajas con un umbral de confianza
101
+ results = gdino_processor.post_process_grounded_object_detection(
102
+ outputs,
103
+ inputs.input_ids,
104
+ box_threshold=box_threshold,
105
+ text_threshold=0.25,
106
+ target_sizes=[imagen.size[::-1]] # (alto, ancho)
107
+ )[0]
108
+
109
+ cajas = results["boxes"] # Tensor con coordenadas [x1, y1, x2, y2]
110
+ etiquetas = results["labels"]
111
+ scores = results["scores"]
112
+
113
+ if len(cajas) == 0:
114
+ return imagen, f"No se encontró nada para '{texto}' con el umbral actual ({box_threshold}). Intenta bajarlo."
115
+
116
+ # 3. SAM 2.1: Segmentar dentro de las cajas encontradas
117
+ sam2_predictor.set_image(imagen_np)
118
+
119
+ # SAM 2.1 requiere que las cajas sean un array numpy
120
+ input_boxes = cajas.cpu().numpy()
121
+
122
+ masks, _, _ = sam2_predictor.predict(
123
+ point_coords=None,
124
+ point_labels=None,
125
+ box=input_boxes,
126
+ multimask_output=False, # Queremos 1 máscara final por caja
127
+ )
128
+
129
+ # Las máscaras de SAM tienen forma (N, 1, H, W). Las aplanamos a (N, H, W)
130
+ masks = masks.squeeze(1)
131
 
132
+ # 4. SUPERPONER MÁSCARAS
133
+ resultado_img = create_mask_overlay(imagen, masks)
 
 
134
 
135
+ # Preparar el mensaje de estado
136
+ objetos_encontrados = [f"{label} ({score:.2f})" for label, score in zip(etiquetas, scores)]
137
+ mensaje = f"Encontrados {len(cajas)} objeto(s): {', '.join(objetos_encontrados)}"
138
 
139
+ return resultado_img, mensaje
 
 
 
140
 
141
  def crear_app():
142
+ with gr.Blocks(title="GroundingDINO + SAM 2.1") as demo:
143
+ gr.Markdown("# 🦖 GroundingDINO + 🎯 SAM 2.1 (Base Plus)")
144
  gr.Markdown(
145
+ "Segmentación de alta precisión basada en texto. Escribe lo que buscas (ej. `bed`, `lamp`, `pillow`).\n\n"
146
+ "*Nota: La primera imagen tardará unos segundos mientras se inicializa la GPU.*"
147
  )
148
 
149
+ with gr.Row():
150
  with gr.Column(scale=1):
151
+ imagen_entrada = gr.Image(type="pil", label="Sube tu foto")
152
+ texto_objeto = gr.Textbox(label="Buscar objeto (en inglés funciona mejor)", placeholder="Ej. bed, pillow, carpet")
153
+
154
+ # Deslizador para ajustar la sensibilidad de GroundingDINO
155
+ umbral = gr.Slider(minimum=0.1, maximum=0.9, value=0.3, step=0.05, label="Umbral de detección (Box Threshold)", info="Bájalo si no detecta el objeto, súbelo si detecta cosas incorrectas.")
156
+
157
+ boton = gr.Button("Segmentar", variant="primary")
158
+
159
  with gr.Column(scale=1):
160
+ imagen_salida = gr.Image(label="Resultado Segmentado")
161
  estado = gr.Textbox(label="Estado", interactive=False)
162
 
163
  boton.click(
164
+ fn=segmentar_con_dino_y_sam,
165
+ inputs=[imagen_entrada, texto_objeto, umbral],
166
  outputs=[imagen_salida, estado],
167
  )
168
 
169
  return demo
170
 
171
  # --- INICIALIZACIÓN GLOBAL ---
172
+ print("Descargando peso de SAM 2.1 al iniciar Space...")
173
+ download_sam_checkpoint()
 
174
 
 
175
  demo = crear_app()
176
 
177
  if __name__ == "__main__":