from collections import Counter

import torch
import gradio as gr
from PIL import Image, ImageDraw, ImageFont
from transformers import DetrImageProcessor, DetrForObjectDetection

# =========================
# Model configuration
# =========================
MODEL_NAME = "facebook/detr-resnet-50"

processor = DetrImageProcessor.from_pretrained(MODEL_NAME)
model = DetrForObjectDetection.from_pretrained(MODEL_NAME)
model.eval()

# Spaces usually run on CPU, but use a GPU whenever one is available.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model.to(DEVICE)


def _load_font(size: int = 14):
    """Return DejaVuSans at *size*, falling back to PIL's built-in font.

    ``ImageFont.truetype`` raises OSError when the font file cannot be
    found/read, and ImportError when Pillow was built without FreeType;
    both are treated as "use the default font" (best-effort, never fatal).
    """
    try:
        return ImageFont.truetype("DejaVuSans.ttf", size)
    except (OSError, ImportError):
        return ImageFont.load_default()


# Loaded once at import time instead of on every _draw_boxes call.
_FONT = _load_font()


def _draw_boxes(
    image: Image.Image,
    detections,
    line_width: int = 3,
    color: str = "red",
) -> Image.Image:
    """Draw bounding boxes and "label score" captions on a copy of *image*.

    Args:
        image: source PIL image (not modified).
        detections: list of dicts with keys ``label``, ``score`` and
            ``box`` = [x1, y1, x2, y2] in pixel coordinates.
        line_width: box outline width in pixels.
        color: box outline colour. An explicit colour is required because
            PIL's default ink on RGB images is white, which makes boxes
            invisible on light backgrounds.

    Returns:
        A new RGB image with the annotations drawn.
    """
    img = image.convert("RGB")  # convert() always returns a new image
    draw = ImageDraw.Draw(img)
    pad = 3  # padding between caption text and its background box

    for det in detections:
        x1, y1, x2, y2 = det["box"]
        text = f"{det['label']} {det['score']:.2f}"

        # Box
        draw.rectangle([x1, y1, x2, y2], outline=color, width=line_width)

        # textbbox at (0, 0) can have a non-zero left/top bearing; measure the
        # real ink extent and offset the draw position by it below.
        left, top, right, bottom = draw.textbbox((0, 0), text, font=_FONT)
        label_w = (right - left) + 2 * pad
        label_h = (bottom - top) + 2 * pad

        # Caption goes just above the box; if the box touches the top edge
        # there is no room, so place it just inside the box instead.
        label_top = y1 - label_h if y1 - label_h >= 0 else y1

        draw.rectangle(
            [x1, label_top, x1 + label_w, label_top + label_h], fill="black"
        )
        draw.text(
            (x1 + pad - left, label_top + pad - top),
            text,
            fill="white",
            font=_FONT,
        )

    return img


def detect_objects(
    image: Image.Image,
    threshold: float = 0.7,
    top_k: int = 10,
    show_boxes: bool = True,
):
    """Detect objects with DETR.

    Args:
        image: input PIL image; ``None`` returns a prompt asking for one.
        threshold: minimum confidence passed to
            ``post_process_object_detection``.
        top_k: keep at most this many detections, highest score first.
        show_boxes: when False the (RGB) input image is returned unannotated.

    Returns:
        A 3-tuple:
        1) the annotated image (or the plain image),
        2) table rows ``[label, score (3 decimals), "[x1, y1, x2, y2]" (1 decimal)]``
           sorted by descending score,
        3) a text summary with the count per class.
    """
    if image is None:
        return None, [], "Por favor sube una imagen."

    # The processor normalizes with a 3-channel mean/std, so RGBA, greyscale
    # or palette images must be converted before preprocessing.
    image = image.convert("RGB")

    # Preprocessing
    inputs = processor(images=image, return_tensors="pt").to(DEVICE)

    # Inference
    with torch.no_grad():
        outputs = model(**inputs)

    # Post-processing: PIL's size is (width, height); DETR wants (height, width).
    target_sizes = torch.tensor([image.size[::-1]], device=DEVICE)
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=threshold
    )[0]

    labels = results["labels"].tolist()
    scores = results["scores"].tolist()
    boxes = results["boxes"].tolist()

    if not labels:
        msg = (
            f"No se detectaron objetos con threshold={threshold:.2f}. "
            "Prueba bajándolo a 0.6–0.7 y usa una imagen con objetos claros (personas, carros, perros)."
        )
        return image, [], msg

    # Map class ids to human-readable COCO names.
    id2label = model.config.id2label
    detections = [
        {
            "label": id2label.get(label_id, str(label_id)),
            "score": float(score),
            "box": [float(v) for v in box],
        }
        for label_id, score, box in zip(labels, scores, boxes)
    ]

    # Sort by score and keep the top-k.
    detections = sorted(detections, key=lambda d: d["score"], reverse=True)[: int(top_k)]

    # Rows for gr.Dataframe. The box is rendered as text so the cell shows
    # "[x1, y1, x2, y2]" instead of depending on how a nested list is displayed.
    table = [
        [d["label"], round(d["score"], 3), str([round(v, 1) for v in d["box"]])]
        for d in detections
    ]

    # Per-class summary
    counts = Counter(d["label"] for d in detections)
    summary = "Resumen (top-k): " + ", ".join(
        f"{k}: {v}" for k, v in sorted(counts.items())
    )

    annotated = _draw_boxes(image, detections) if show_boxes else image
    return annotated, table, summary


# =========================
# Interface
# =========================
with gr.Blocks(title="Detección de Objetos con DETR (Transformers)") as demo:
    gr.Markdown(
        """
# Detección de Objetos con DETR (Hugging Face Transformers)

Sube una imagen y el modelo **DETR** detectará objetos del dataset **COCO**.

**Tip:** Si no detecta nada, baja el *threshold* a **0.6–0.7**.
"""
    )

    with gr.Row():
        inp_image = gr.Image(type="pil", label="Sube una imagen")
        out_image = gr.Image(type="pil", label="Imagen con detecciones")

    with gr.Row():
        threshold = gr.Slider(0.1, 0.99, value=0.7, step=0.01, label="Threshold (confianza)")
        top_k = gr.Slider(1, 50, value=10, step=1, label="Top-K detecciones")
        show_boxes = gr.Checkbox(value=True, label="Mostrar bounding boxes")

    btn = gr.Button("Detectar objetos")

    out_table = gr.Dataframe(
        headers=["Objeto", "Score", "Box [x1,y1,x2,y2]"],
        label="Detecciones (ordenadas por score)",
        wrap=True,
    )
    out_summary = gr.Textbox(label="Resumen")

    btn.click(
        fn=detect_objects,
        inputs=[inp_image, threshold, top_k, show_boxes],
        outputs=[out_image, out_table, out_summary],
    )


if __name__ == "__main__":
    demo.launch()