dbeetar's picture
Update app.py
5ccead5 verified
Raw
History Blame Contribute Delete
5.27 kB
import torch
import gradio as gr
from PIL import Image, ImageDraw, ImageFont
from transformers import DetrImageProcessor, DetrForObjectDetection
# =========================
# Model configuration
# =========================
MODEL_NAME = "facebook/detr-resnet-50"

# Spaces usually run on CPU, but use CUDA automatically when it is available.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Pre/post-processing helper that matches the checkpoint.
processor = DetrImageProcessor.from_pretrained(MODEL_NAME)

# Load the detector, put it in inference mode and move it to DEVICE.
# Both ``eval()`` and ``to()`` return the module itself, so chaining is safe.
model = DetrForObjectDetection.from_pretrained(MODEL_NAME).eval().to(DEVICE)
def _draw_boxes(image: Image.Image, detections, line_width: int = 3) -> Image.Image:
    """Return an RGB copy of ``image`` with boxes and labels drawn on it.

    Args:
        image: Source PIL image. It is copied, never modified in place.
        detections: Iterable of dicts with the keys ``label``, ``score`` and
            ``box`` (``[x1, y1, x2, y2]``).
        line_width: Stroke width of each box outline.

    Returns:
        A new RGB image carrying the annotations.
    """
    img = image.copy().convert("RGB")
    draw = ImageDraw.Draw(img)

    # Prefer a TrueType font; fall back to Pillow's built-in font when the
    # file (or FreeType support) is not available.
    try:
        font = ImageFont.truetype("DejaVuSans.ttf", 14)
    except (OSError, ImportError):
        font = ImageFont.load_default()

    pad = 3  # space between the label text and the edge of its background

    for det in detections:
        x1, y1, x2, y2 = det["box"]

        # Box outline (no colour given, so Pillow's default ink is used).
        draw.rectangle([x1, y1, x2, y2], width=line_width)

        # Label text and its rendered size.
        text = f"{det['label']} {det['score']:.2f}"
        left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
        text_w = right - left
        text_h = bottom - top

        # Place the label just above the box, but keep its origin inside the
        # image. The background is sized from the text instead of ending at
        # ``y1``: for a box touching (or starting above) the top edge, the old
        # ``[.., max(0, ..), .., y1]`` rectangle was degenerate or inverted,
        # which recent Pillow releases reject with ValueError.
        label_left = max(0, x1)
        label_top = max(0, y1 - text_h - 2 * pad)
        draw.rectangle(
            [
                label_left,
                label_top,
                label_left + text_w + 2 * pad,
                label_top + text_h + 2 * pad,
            ],
            fill="black",
        )
        draw.text(
            (label_left + pad, label_top + pad), text, fill="white", font=font
        )

    return img
def detect_objects(
    image: Image.Image,
    threshold: float = 0.7,
    top_k: int = 10,
    show_boxes: bool = True,
):
    """Detect objects in ``image`` with DETR and summarise the result.

    Args:
        image: PIL image to analyse, or ``None`` when nothing was uploaded.
        threshold: Minimum confidence score for a detection to be kept.
        top_k: Maximum number of detections to report (highest scores first).
            Values below zero are treated as zero.
        show_boxes: When true, the returned image has the boxes drawn on it;
            otherwise the input image is returned unchanged.

    Returns:
        A 3-tuple ``(annotated, table, summary)``:
        ``annotated`` is the output image (``None`` if no image was given),
        ``table`` is a list of ``[label, score, [x1, y1, x2, y2]]`` rows, and
        ``summary`` is a user-facing message (per-class counts, or a hint
        when there is nothing to report).
    """
    if image is None:
        return None, [], "Por favor sube una imagen."

    # Feed the model a 3-channel image; other PIL modes (RGBA, L, P, ...) are
    # converted first. The caller's image object is left untouched.
    rgb_image = image if image.mode == "RGB" else image.convert("RGB")

    # Pre-processing
    inputs = processor(images=rgb_image, return_tensors="pt").to(DEVICE)

    # Inference (no gradients needed)
    with torch.no_grad():
        outputs = model(**inputs)

    # Post-processing: PIL reports (width, height), the processor is given
    # (height, width).
    target_sizes = torch.tensor([rgb_image.size[::-1]], device=DEVICE)
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=threshold
    )[0]
    labels = results["labels"].tolist()
    scores = results["scores"].tolist()
    boxes = results["boxes"].tolist()

    if not labels:
        # Only suggest the 0.6-0.7 range when that actually means lowering
        # the current threshold.
        if threshold > 0.7:
            hint = "Prueba bajándolo a 0.6–0.7"
        else:
            hint = "Prueba con un valor más bajo"
        msg = (
            f"No se detectaron objetos con threshold={threshold:.2f}. "
            f"{hint} y usa una imagen con objetos claros (personas, carros, perros)."
        )
        return image, [], msg

    # Convert raw ids/tensors into plain dicts with a readable class name.
    id2label = model.config.id2label
    detections = [
        {
            "label": id2label.get(label_id, str(label_id)),
            "score": float(score),
            "box": [float(coord) for coord in box],
        }
        for label_id, score, box in zip(labels, scores, boxes)
    ]

    # Sort by score and keep the best ``top_k``. Clamp at zero so a negative
    # value cannot turn the slice into "drop the last N".
    limit = max(0, int(top_k))
    detections = sorted(detections, key=lambda d: d["score"], reverse=True)[:limit]

    # Table for Gradio (Dataframe accepts a list of lists).
    table = [
        [d["label"], round(d["score"], 3), [round(v, 1) for v in d["box"]]]
        for d in detections
    ]

    # Per-class count of the detections that were kept.
    counts = {}
    for d in detections:
        counts[d["label"]] = counts.get(d["label"], 0) + 1
    summary = "Resumen (top-k): " + ", ".join(
        f"{name}: {count}" for name, count in sorted(counts.items())
    )

    # Annotated image (optional)
    annotated = _draw_boxes(image, detections) if show_boxes else image
    return annotated, table, summary
# =========================
# Interface (improved UX)
# =========================
with gr.Blocks(title="Detección de Objetos con DETR (Transformers)") as demo:
    # Introductory text shown at the top of the page.
    gr.Markdown(
        "\n"
        "# Detección de Objetos con DETR (Hugging Face Transformers)\n"
        "Sube una imagen y el modelo **DETR** detectará objetos del dataset **COCO**.\n"
        "**Tip:** Si no detecta nada, baja el *threshold* a **0.6–0.7**.\n"
    )

    # Input image next to the annotated result.
    with gr.Row():
        inp_image = gr.Image(type="pil", label="Sube una imagen")
        out_image = gr.Image(type="pil", label="Imagen con detecciones")

    # Detection controls.
    with gr.Row():
        threshold = gr.Slider(
            minimum=0.1,
            maximum=0.99,
            value=0.7,
            step=0.01,
            label="Threshold (confianza)",
        )
        top_k = gr.Slider(
            minimum=1,
            maximum=50,
            value=10,
            step=1,
            label="Top-K detecciones",
        )
        show_boxes = gr.Checkbox(value=True, label="Mostrar bounding boxes")

    btn = gr.Button("Detectar objetos")

    # Tabular and textual views of the detections.
    out_table = gr.Dataframe(
        headers=["Objeto", "Score", "Box [x1,y1,x2,y2]"],
        label="Detecciones (ordenadas por score)",
        wrap=True,
    )
    out_summary = gr.Textbox(label="Resumen")

    # Wire the button to the detector; the list order matches the parameters
    # and the return tuple of ``detect_objects``.
    detector_inputs = [inp_image, threshold, top_k, show_boxes]
    detector_outputs = [out_image, out_table, out_summary]
    btn.click(fn=detect_objects, inputs=detector_inputs, outputs=detector_outputs)

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()