File size: 3,685 Bytes
ba07818
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import gradio as gr
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import DetrImageProcessor, DetrForObjectDetection

# --- Load the model once at import time so every request reuses it ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
model.eval()  # switch to evaluation mode for inference
model.to(DEVICE)

# Class-id -> class-name mapping taken from the checkpoint config, and the
# de-duplicated, alphabetically sorted names offered in the UI class filter.
ID2LABEL = model.config.id2label
ALL_CLASSES = sorted({name for name in ID2LABEL.values()})

def _annotate(image: Image.Image, detections):
    """Draw detection boxes and "label score" captions on a copy of *image*.

    Args:
        image: Source PIL image. It is not modified; drawing happens on a copy.
        detections: Iterable of dicts with keys ``"box_xyxy"``
            (``[x0, y0, x1, y1]`` in pixel coordinates), ``"label"`` (str)
            and ``"score"`` (float).

    Returns:
        A new PIL image with a red outline around every box and a red caption
        tag containing ``"<label> <score:.2f>"`` in white text.
    """
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    try:
        font = ImageFont.truetype("arial.ttf", 16)
    except (OSError, ImportError):
        # OSError: Arial is not installed (common on Linux containers).
        # ImportError: Pillow was built without FreeType support.
        font = ImageFont.load_default()

    for d in detections:
        x0, y0, x1, y1 = d["box_xyxy"]
        txt = f"{d['label']} {d['score']:.2f}"

        # Bounding box
        draw.rectangle([x0, y0, x1, y1], outline="red", width=3)

        # Caption with a filled background, sized from the actual font metrics
        # (the fallback font is not 16 px tall).
        tw, th = _text_size(draw, txt, font)
        tag_h = th + 4
        # Prefer placing the tag just above the box; if that would fall off
        # the top of the image, draw it just inside the box so it stays visible.
        tag_top = y0 - tag_h if y0 - tag_h >= 0 else y0
        draw.rectangle([x0, tag_top, x0 + tw + 6, tag_top + tag_h], fill="red")
        draw.text((x0 + 3, tag_top + 2), txt, fill="white", font=font)
    return annotated


def _text_size(draw, text, font):
    """Return the (width, height) in pixels needed to draw *text* with *font* at the origin."""
    try:
        # textbbox is available from Pillow 8.0; early 8.x/9.x versions raise
        # ValueError for non-TrueType (bitmap) fonts.
        _, _, right, bottom = draw.textbbox((0, 0), text, font=font)
    except (AttributeError, ValueError):
        # Older Pillow: getsize exists there (it was removed in Pillow 10).
        return font.getsize(text)
    # Use the far edges measured from the origin so the tag fully covers any
    # ascender/left-bearing offset of the glyphs.
    return right, bottom

def detect(image, threshold=0.9, classes=None, topk=0):
    """Detect objects in *image* with DETR and return ``(annotated_image, detections)``.

    Args:
        image: PIL image to analyse, or ``None`` (e.g. nothing uploaded yet).
            Images whose mode is not ``"RGB"`` (RGBA, grayscale, palette, ...)
            are converted to RGB first, since the DETR processor works on
            3-channel input.
        threshold: Minimum confidence score a detection must reach; forwarded
            to ``post_process_object_detection``.
        classes: Optional iterable of label names. When non-empty, only
            detections whose label is in it are kept.
        topk: Keep only the ``topk`` highest-scoring detections. ``0`` (or any
            falsy value) means no limit.

    Returns:
        ``(None, [])`` when *image* is ``None``. Otherwise a tuple
        ``(annotated, detections)``. ``annotated`` is the image with boxes
        drawn by ``_annotate``. ``detections`` is a list of dicts
        ``{'label', 'score', 'box_xyxy'}``, with ``score`` rounded to 4
        decimals and the ``[x0, y0, x1, y1]`` box rounded to 2 decimals.
    """
    if image is None:
        return None, []
    # RGBA / grayscale / palette images (e.g. when detect() is called directly
    # rather than through the UI) may otherwise fail in preprocessing.
    if image.mode != "RGB":
        image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL's size is (width, height); post-processing wants (height, width).
    target_sizes = torch.tensor([image.size[::-1]], device=DEVICE)
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=float(threshold)
    )[0]

    dets = []
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        label_id = int(label)
        dets.append({
            "label_id": label_id,
            # Fall back to the numeric id instead of raising KeyError if the
            # config has no name for this class.
            "label": ID2LABEL.get(label_id, str(label_id)),
            "score": float(score),
            "box_xyxy": [float(v) for v in box.tolist()],
        })

    # Optional class filter (an empty selection keeps everything).
    if classes:
        allow = set(classes)
        dets = [d for d in dets if d["label"] in allow]

    # Top-K by score (0 = no limit). int() because the UI slider may pass a float.
    if topk and int(topk) > 0:
        dets = sorted(dets, key=lambda d: d["score"], reverse=True)[:int(topk)]

    annotated = _annotate(image, dets)
    nice_dets = [
        {
            "label": d["label"],
            "score": round(d["score"], 4),
            "box_xyxy": [round(v, 2) for v in d["box_xyxy"]],
        }
        for d in dets
    ]
    return annotated, nice_dets

# --- Gradio UI ---
# Components are laid out in the order they are created inside the
# Blocks/Row/Column contexts, so the statements below must keep this order.
with gr.Blocks(title="DETR Object Detection") as demo:
    gr.Markdown("## DETR Object Detection (Transformers + Gradio)\nSube una imagen, ajusta umbral y filtros.")
    with gr.Row():
        # Left: the input image, handed to detect() as a PIL image.
        img = gr.Image(type="pil", label="Imagen de entrada")
        # Right: detection controls, stacked vertically.
        with gr.Column():
            # Minimum confidence score; becomes detect()'s `threshold`.
            thr = gr.Slider(0.10, 0.99, value=0.90, step=0.01, label="Umbral (threshold)")
            # Optional label filter; the choices are every class name in ID2LABEL.
            classes = gr.CheckboxGroup(choices=ALL_CLASSES, label="Filtrar por clases (opcional)")
            # Maximum number of top-scoring detections to keep; 0 = no limit.
            topk = gr.Slider(0, 200, value=0, step=1, label="Top-K por score (0 = sin límite)")
            btn = gr.Button("Detectar", variant="primary")

    # Outputs: the annotated image and the detection list rendered as JSON.
    out_img = gr.Image(type="pil", label="Imagen anotada")
    out_json = gr.JSON(label="Detecciones (JSON)")

    # On click, inputs map positionally to detect(image, threshold, classes, topk)
    # and its (annotated_image, detections) result maps to the two outputs.
    btn.click(detect, inputs=[img, thr, classes, topk], outputs=[out_img, out_json])

# Launch the Gradio app only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()