Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from PIL import Image, ImageDraw, ImageFont | |
| from transformers import DetrImageProcessor, DetrForObjectDetection | |
# --- Load the processor and model once, at import time ---
_CHECKPOINT = "facebook/detr-resnet-50"

# Prefer the GPU when torch reports one, otherwise run on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

processor = DetrImageProcessor.from_pretrained(_CHECKPOINT)
model = DetrForObjectDetection.from_pretrained(_CHECKPOINT)
model.eval()  # evaluation mode: this script only runs inference
model.to(DEVICE)

# Mapping from class id to label name, as shipped in the model config.
ID2LABEL = model.config.id2label
# De-duplicated label names in sorted order.
ALL_CLASSES = sorted({*ID2LABEL.values()})
def _annotate(image: Image.Image, detections):
    """Return a copy of ``image`` with every detection drawn on it.

    Each detection gets a red rectangle plus a red tag with white text of the
    form ``"<label> <score>"`` (score formatted with two decimals). The tag is
    placed just above the box; when that would put it outside the top of the
    image it is drawn starting at the box's top edge instead, so it stays
    visible.

    Args:
        image: Source image. It is copied and never modified.
        detections: Iterable of dicts with the keys ``"box_xyxy"``
            (``[x0, y0, x1, y1]`` in image coordinates), ``"label"`` and
            ``"score"``.

    Returns:
        The annotated copy of ``image``.
    """
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    try:
        font = ImageFont.truetype("arial.ttf", 16)
    except (OSError, ImportError):
        # Font file not found/readable, or FreeType support is missing:
        # fall back to Pillow's built-in font.
        font = ImageFont.load_default()

    # ``textbbox`` is the supported way to measure text in current Pillow;
    # ``font.getsize`` only exists in old releases.
    has_textbbox = hasattr(draw, "textbbox")

    for det in detections:
        x0, y0, x1, y1 = det["box_xyxy"]
        txt = f"{det['label']} {det['score']:.2f}"

        # Bounding box
        draw.rectangle([x0, y0, x1, y1], outline="red", width=3)

        # Measure the label with the font actually in use, so the tag also
        # fits when the fallback font is smaller than 16 px.
        if has_textbbox:
            _, _, tw, th = draw.textbbox((0, 0), txt, font=font)
        else:
            tw, th = font.getsize(txt)

        # Label tag: above the box by default, moved down when it would be
        # clipped by the top edge of the image.
        tag_height = th + 4
        tag_top = y0 - tag_height
        if tag_top < 0:
            tag_top = max(y0, 0)
        draw.rectangle(
            [x0, tag_top, x0 + tw + 6, tag_top + tag_height], fill="red"
        )
        draw.text((x0 + 3, tag_top + 2), txt, fill="white", font=font)
    return annotated
def detect(image, threshold=0.9, classes=None, topk=0):
    """Detect objects with DETR and return ``(annotated_image, detections)``.

    Args:
        image: PIL image to analyse, or ``None``.
        threshold: Minimum confidence score; passed to the processor's
            post-processing step as a float.
        classes: Optional collection of label names to keep. A falsy value
            (``None`` or empty) keeps every class.
        topk: Keep only the ``topk`` highest-scoring detections. ``0`` (or any
            non-positive / falsy value) means no limit.

    Returns:
        ``(None, [])`` when ``image`` is ``None``. Otherwise a tuple of the
        annotated image and a list of dicts with the keys ``'label'``,
        ``'score'`` (rounded to 4 decimals) and ``'box_xyxy'`` (rounded to
        2 decimals).
    """
    if image is None:
        return None, []

    # The processor normalises three colour channels; convert grayscale,
    # palette and alpha images up front instead of failing in preprocessing.
    if image.mode != "RGB":
        image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports (width, height); post-processing expects (height, width).
    target_sizes = torch.tensor([image.size[::-1]], device=DEVICE)
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=float(threshold)
    )[0]

    # Convert each result tensor to plain Python values in one call rather
    # than once per element.
    scores = results["scores"].tolist()
    label_ids = results["labels"].tolist()
    boxes = results["boxes"].tolist()
    dets = [
        {
            "label_id": label_id,
            "label": ID2LABEL[label_id],
            "score": score,
            "box_xyxy": box,
        }
        for score, label_id, box in zip(scores, label_ids, boxes)
    ]

    # Optional class filter
    if classes:
        allow = set(classes)
        dets = [d for d in dets if d["label"] in allow]

    # Top-K by score (0 = no limit)
    limit = int(topk) if topk else 0
    if limit > 0:
        dets = sorted(dets, key=lambda d: d["score"], reverse=True)[:limit]

    annotated = _annotate(image, dets)

    # Rounded, JSON-friendly view of the detections.
    nice_dets = [
        {
            "label": d["label"],
            "score": round(d["score"], 4),
            "box_xyxy": [round(v, 2) for v in d["box_xyxy"]],
        }
        for d in dets
    ]
    return annotated, nice_dets
# Gradio UI: an input image plus controls (threshold, class filter, top-K and
# a run button), and two outputs (annotated image and detections as JSON).
# NOTE(review): the original indentation was lost in this copy of the file;
# the nesting of Row/Column and the position of the outputs below are
# reconstructed — confirm against the deployed layout.
with gr.Blocks(title="DETR Object Detection") as demo:
    gr.Markdown("## DETR Object Detection (Transformers + Gradio)\nSube una imagen, ajusta umbral y filtros.")
    with gr.Row():
        # type="pil" makes Gradio hand `detect` a PIL image.
        img = gr.Image(type="pil", label="Imagen de entrada")
        with gr.Column():
            thr = gr.Slider(0.10, 0.99, value=0.90, step=0.01, label="Umbral (threshold)")
            classes = gr.CheckboxGroup(choices=ALL_CLASSES, label="Filtrar por clases (opcional)")
            topk = gr.Slider(0, 200, value=0, step=1, label="Top-K por score (0 = sin límite)")
            btn = gr.Button("Detectar", variant="primary")
    out_img = gr.Image(type="pil", label="Imagen anotada")
    out_json = gr.JSON(label="Detecciones (JSON)")
    # Inputs are passed to `detect` positionally in this order; its two return
    # values fill the annotated image and the JSON panel.
    btn.click(detect, inputs=[img, thr, classes, topk], outputs=[out_img, out_json])

# Launch the app only when the file is run as a script.
if __name__ == "__main__":
    demo.launch()