practica2 / app.py
acascanzal's picture
Update app.py
42da8c6 verified
import gradio as gr
from transformers import AutoImageProcessor, AutoModelForObjectDetection
import torch
from PIL import Image, ImageDraw, ImageFont
# --- CONFIGURACIÓN ---
MODEL_REPO = "acascanzal/practica2"
BASE_PROCESSOR = "hustvl/yolos-tiny"
THRESHOLD = 0.6
# --- CARGA DEL MODELO ---
try:
image_processor = AutoImageProcessor.from_pretrained(BASE_PROCESSOR)
model = AutoModelForObjectDetection.from_pretrained(MODEL_REPO)
except Exception as e:
print(f"Error fatal: {e}")
raise e
# --- FUNCIÓN PRINCIPAL ---
def detect_objects(image):
if image is None:
return None
# Asegurar que la imagen es RGB (por si suben PNGs transparentes)
image = image.convert("RGB")
# 1. Preparar la imagen (El procesador se encarga del tamaño y normalización)
inputs = image_processor(images=image, return_tensors="pt")
# 2. Predicción
with torch.no_grad():
outputs = model(**inputs)
# 3. Convertir resultados (Cajas relativas -> Coordenadas reales)
target_sizes = torch.tensor([image.size[::-1]])
results = image_processor.post_process_object_detection(
outputs,
target_sizes=target_sizes,
threshold=THRESHOLD
)[0]
# 4. Dibujar en la imagen
draw = ImageDraw.Draw(image)
# Intentamos cargar una fuente grande, si no, usamos la default
try:
font = ImageFont.truetype("arial.ttf", 20)
except:
font = ImageFont.load_default()
# Recorrer cada objeto detectado
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
# Coordenadas de la caja
box = [round(i, 2) for i in box.tolist()]
xmin, ymin, xmax, ymax = box
# Obtener el nombre de la clase (ej. "gato", "perro")
# Si tu modelo no guardó los nombres, mostrará "LABEL_0", "LABEL_1", etc.
label_text = model.config.id2label.get(label.item(), f"Clase {label.item()}")
# Crear texto: "Clase: 99%"
display_text = f"{label_text}: {round(score.item(), 2)}"
# DIBUJAR
# Caja roja
draw.rectangle(box, outline="red", width=3)
# Fondo para el texto (para que se lea bien)
text_bbox = draw.textbbox((xmin, ymin), display_text, font=font)
draw.rectangle(text_bbox, fill="red")
# Texto blanco
draw.text((xmin, ymin), display_text, fill="white", font=font)
return image
# --- INTERFAZ ---
interface = gr.Interface(
fn=detect_objects,
inputs=gr.Image(type="pil", label="Sube tu imagen aquí"),
outputs=gr.Image(type="pil", label="Objetos Detectados"),
title="Detector de Objetos (YOLO Tiny)",
description=f"Modelo cargado desde: {MODEL_REPO}. Umbral de confianza: {THRESHOLD}",
examples=[], # Puedes añadir fotos de ejemplo aquí si quieres ["foto1.jpg", "foto2.jpg"]
theme="default"
)
# Lanzar la app
if __name__ == "__main__":
interface.launch()