Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import AutoImageProcessor, AutoModelForObjectDetection | |
| import torch | |
| from PIL import Image, ImageDraw, ImageFont | |
# --- Configuration ---
# Checkpoint whose image processor handles pre-processing and box
# post-processing for the detector.
BASE_PROCESSOR = "hustvl/yolos-tiny"
# Hugging Face Hub repository holding the fine-tuned detection model.
MODEL_REPO = "acascanzal/practica2"
# Confidence threshold passed to post_process_object_detection; detections
# scoring below it are discarded before drawing.
THRESHOLD = 0.6
# --- Model loading ---
# The image processor and the detection model are loaded at import time.
# Any failure is printed and then re-raised, so the app never starts with
# `image_processor` / `model` undefined.
try:
    image_processor = AutoImageProcessor.from_pretrained(BASE_PROCESSOR)
    model = AutoModelForObjectDetection.from_pretrained(MODEL_REPO)
except Exception as e:
    print(f"Error fatal: {e}")
    # A bare `raise` re-raises the active exception with its original
    # traceback instead of re-raising it by name (ruff TRY201).
    raise
# --- Main function ---
def _load_label_font(size=20):
    """Return the TrueType "arial.ttf" font at *size*, or Pillow's default font.

    `ImageFont.truetype` raises `OSError` when the font file cannot be
    opened. `ImportError` is also caught in case Pillow was built without
    FreeType support. Other exceptions (including KeyboardInterrupt) are no
    longer swallowed, unlike with the previous bare `except:`.
    """
    try:
        return ImageFont.truetype("arial.ttf", size)
    except (OSError, ImportError):
        return ImageFont.load_default()


def _draw_detection(draw, font, label_text, score, box):
    """Draw one detection onto *draw*.

    The detection is rendered as a red box with a white caption on a red
    background.

    Args:
        draw: `ImageDraw.Draw` bound to the image being annotated.
        font: Font used for the caption.
        label_text: Class name to display.
        score: Detection confidence as a Python float.
        box: `[xmin, ymin, xmax, ymax]` in pixel coordinates.
    """
    xmin, ymin = box[0], box[1]
    # Caption format is "<label>: <score rounded to 2 decimals>", e.g. "cat: 0.87".
    display_text = f"{label_text}: {round(score, 2)}"

    # Red bounding box.
    draw.rectangle(box, outline="red", width=3)
    # Filled background behind the caption so the text stays legible.
    text_bbox = draw.textbbox((xmin, ymin), display_text, font=font)
    draw.rectangle(text_bbox, fill="red")
    # White caption text.
    draw.text((xmin, ymin), display_text, fill="white", font=font)


def detect_objects(image):
    """Run object detection on *image* and return an annotated copy.

    Args:
        image: PIL image received from the Gradio input, or `None` when
            nothing was uploaded.

    Returns:
        A new RGB PIL image. For every detection kept by the `THRESHOLD`
        filter it carries a red box and a caption of the form
        `"<label>: <score>"`, with the score rounded to 2 decimals
        (e.g. `"cat: 0.87"`). Returns `None` when *image* is `None`.
    """
    if image is None:
        return None

    # Force RGB (e.g. for PNGs with transparency). `convert` returns a new
    # image, so the caller's original image is never drawn on.
    image = image.convert("RGB")

    # 1. Pre-process: the processor takes care of resizing and normalization.
    inputs = image_processor(images=image, return_tensors="pt")

    # 2. Inference without gradient tracking.
    with torch.no_grad():
        outputs = model(**inputs)

    # 3. Convert predicted boxes to pixel coordinates. PIL's `size` is
    #    (width, height), while `target_sizes` expects (height, width),
    #    hence the reversal.
    target_sizes = torch.tensor([image.size[::-1]])
    results = image_processor.post_process_object_detection(
        outputs,
        target_sizes=target_sizes,
        threshold=THRESHOLD,
    )[0]

    # 4. Draw every kept detection onto the image.
    draw = ImageDraw.Draw(image)
    font = _load_label_font()
    id2label = model.config.id2label
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        label_id = label.item()
        # If the checkpoint saved no class names, the config's mapping holds
        # generic names like "LABEL_0". "Clase N" is used only for ids that
        # are missing from the mapping altogether.
        label_text = id2label.get(label_id, f"Clase {label_id}")
        coords = [round(c, 2) for c in box.tolist()]
        _draw_detection(draw, font, label_text, score.item(), coords)

    return image
# --- Interface ---
# Input/output components and the description are built up front, then
# handed to gr.Interface. Example images could be listed in `examples`,
# e.g. ["foto1.jpg", "foto2.jpg"].
_input_image = gr.Image(type="pil", label="Sube tu imagen aquí")
_output_image = gr.Image(type="pil", label="Objetos Detectados")
_description = f"Modelo cargado desde: {MODEL_REPO}. Umbral de confianza: {THRESHOLD}"

interface = gr.Interface(
    fn=detect_objects,
    inputs=_input_image,
    outputs=_output_image,
    title="Detector de Objetos (YOLO Tiny)",
    description=_description,
    examples=[],
    theme="default",
)

# Start the web app only when this file is executed directly.
if __name__ == "__main__":
    interface.launch()