import gradio as gr from transformers import AutoImageProcessor, AutoModelForObjectDetection import torch from PIL import Image, ImageDraw, ImageFont # --- CONFIGURACIÓN --- MODEL_REPO = "acascanzal/practica2" BASE_PROCESSOR = "hustvl/yolos-tiny" THRESHOLD = 0.6 # --- CARGA DEL MODELO --- try: image_processor = AutoImageProcessor.from_pretrained(BASE_PROCESSOR) model = AutoModelForObjectDetection.from_pretrained(MODEL_REPO) except Exception as e: print(f"Error fatal: {e}") raise e # --- FUNCIÓN PRINCIPAL --- def detect_objects(image): if image is None: return None # Asegurar que la imagen es RGB (por si suben PNGs transparentes) image = image.convert("RGB") # 1. Preparar la imagen (El procesador se encarga del tamaño y normalización) inputs = image_processor(images=image, return_tensors="pt") # 2. Predicción with torch.no_grad(): outputs = model(**inputs) # 3. Convertir resultados (Cajas relativas -> Coordenadas reales) target_sizes = torch.tensor([image.size[::-1]]) results = image_processor.post_process_object_detection( outputs, target_sizes=target_sizes, threshold=THRESHOLD )[0] # 4. Dibujar en la imagen draw = ImageDraw.Draw(image) # Intentamos cargar una fuente grande, si no, usamos la default try: font = ImageFont.truetype("arial.ttf", 20) except: font = ImageFont.load_default() # Recorrer cada objeto detectado for score, label, box in zip(results["scores"], results["labels"], results["boxes"]): # Coordenadas de la caja box = [round(i, 2) for i in box.tolist()] xmin, ymin, xmax, ymax = box # Obtener el nombre de la clase (ej. "gato", "perro") # Si tu modelo no guardó los nombres, mostrará "LABEL_0", "LABEL_1", etc. label_text = model.config.id2label.get(label.item(), f"Clase {label.item()}") # Crear texto: "Clase: 99%" display_text = f"{label_text}: {round(score.item(), 2)}" # DIBUJAR # Caja roja draw.rectangle(box, outline="red", width=3) # Fondo para el texto (para que se lea bien) text_bbox = draw.textbbox((xmin, ymin), display_text, font=font) draw.rectangle(text_bbox, fill="red") # Texto blanco draw.text((xmin, ymin), display_text, fill="white", font=font) return image # --- INTERFAZ --- interface = gr.Interface( fn=detect_objects, inputs=gr.Image(type="pil", label="Sube tu imagen aquí"), outputs=gr.Image(type="pil", label="Objetos Detectados"), title="Detector de Objetos (YOLO Tiny)", description=f"Modelo cargado desde: {MODEL_REPO}. Umbral de confianza: {THRESHOLD}", examples=[], # Puedes añadir fotos de ejemplo aquí si quieres ["foto1.jpg", "foto2.jpg"] theme="default" ) # Lanzar la app if __name__ == "__main__": interface.launch()