import gradio as gr from PIL import Image from tqdm import tqdm from typing import TypeVar, Tuple import numpy as np from rfdetr import RFDETRBase, RFDETRLarge from rfdetr.util.coco_classes import COCO_CLASSES import supervision as sv from rfdetr.detr import RFDETR import datetime import os import shutil import uuid ImageType = TypeVar("ImageType", Image.Image, np.ndarray) COLOR = sv.ColorPalette.from_hex([ "#ffff00", "#ff9b00", "#ff8080", "#ff66b2", "#ff66ff", "#b266ff", "#9999ff", "#3399ff", "#66ffff", "#33ff99", "#66ff66", "#99ff00" ]) MAX_VIDEO_LENGTH_SECONDS = 5 VIDEO_SCALE_FACTOR = 0.5 VIDEO_TARGET_DIRECTORY = "tmp" def create_directory(directory_path: str) -> None: """Crea un directorio si no existe. Args: directory_path (str): Ruta del directorio a crear. """ if not os.path.exists(directory_path): os.makedirs(directory_path) def delete_directory(directory_path: str) -> None: """Elimina un directorio existente. Args: directory_path (str): Ruta del directorio a eliminar. Raises: FileNotFoundError: Si el directorio no existe. PermissionError: Si no se tienen permisos para eliminarlo. """ if not os.path.exists(directory_path): raise FileNotFoundError(f"El directorio '{directory_path}' no existe.") try: shutil.rmtree(directory_path) except PermissionError: raise PermissionError(f"Permiso denegado: No se puede eliminar '{directory_path}'.") def generate_unique_name() -> str: """Genera un nombre único basado en la fecha y un UUID. Returns: str: Nombre único. """ current_datetime = datetime.datetime.now().strftime("%Y%m%d%H%M%S") unique_id = uuid.uuid4() return f"{current_datetime}_{unique_id}" def calculate_resolution_wh(image: ImageType) -> Tuple[int, int]: """Obtiene la resolución (ancho, alto) de una imagen. Args: image (ImageType): Imagen tipo PIL o NumPy. Returns: Tuple[int, int]: Resolución como (ancho, alto). Raises: ValueError: Si la imagen no tiene al menos dos dimensiones. TypeError: Si el tipo de imagen no es soportado. """ if isinstance(image, Image.Image): return image.size elif isinstance(image, np.ndarray): if image.ndim >= 2: h, w = image.shape[:2] return w, h else: raise ValueError("La imagen numpy debe tener al menos 2 dimensiones (alto, ancho).") else: raise TypeError("La imagen debe ser de tipo PIL.Image o numpy.ndarray.") def load_model(resolution: int, checkpoint: str) -> RFDETR: """Carga un modelo RF-DETR según el checkpoint y la resolución. Args: resolution (int): Resolución de entrada del modelo. checkpoint (str): Checkpoint a usar: 'base' o 'large'. Returns: RFDETR: Modelo cargado listo para inferencia. Raises: TypeError: Si el checkpoint no es válido. """ if checkpoint == "base": return RFDETRBase(resolution=resolution) elif checkpoint == "large": return RFDETRLarge(resolution=resolution) raise TypeError("El checkpoint debe ser 'base' o 'large'.") def detect_and_annotate(model: RFDETR, image: ImageType, confidence: float) -> ImageType: """Detecta objetos en una imagen y la anota con las detecciones. Args: model (RFDETR): Modelo de detección ya cargado. image (ImageType): Imagen sobre la cual detectar objetos. confidence (float): Umbral de confianza para las detecciones. Returns: ImageType: Imagen anotada con cajas y etiquetas. """ detections = model.predict(image, threshold=confidence) resolution_wh = calculate_resolution_wh(image) text_scale = sv.calculate_optimal_text_scale(resolution_wh=resolution_wh) - 0.4 thickness = sv.calculate_optimal_line_thickness(resolution_wh=resolution_wh) - 2 bbox_annotator = sv.BoxAnnotator(color=COLOR, thickness=thickness) label_annotator = sv.LabelAnnotator( color=COLOR, text_color=sv.Color.BLACK, text_scale=text_scale, text_padding=1 ) labels = [ f"{COCO_CLASSES[class_id]} {conf:.2f}" for class_id, conf in zip(detections.class_id, detections.confidence) ] annotated_image = image.copy() annotated_image = bbox_annotator.annotate(annotated_image, detections) annotated_image = label_annotator.annotate(annotated_image, detections, labels) return annotated_image def image_processing_inference(input_image: Image.Image, confidence: float, resolution: int, checkpoint: str) -> Image.Image: """Función principal para inferencia sobre una imagen. Args: input_image (Image.Image): Imagen de entrada. confidence (float): Umbral de confianza para detección. resolution (int): Resolución objetivo del modelo. checkpoint (str): Tipo de modelo ('base' o 'large'). Returns: Image.Image: Imagen procesada con detecciones. """ input_image = input_image.resize((resolution, resolution)) model = load_model(resolution=resolution, checkpoint=checkpoint) return detect_and_annotate(model=model, image=input_image, confidence=confidence) def video_processing_inference(input_video: str, confidence: float, resolution: int, checkpoint: str, progress=gr.Progress(track_tqdm=True)) -> str: """Función principal para inferencia sobre un video. Args: input_video (str): Ruta al archivo de video. confidence (float): Umbral de confianza. resolution (int): Resolución del modelo. checkpoint (str): Tipo de modelo ('base' o 'large'). progress (gr.Progress): Barra de progreso de Gradio. Returns: str: Ruta al video procesado con anotaciones. """ model = load_model(resolution=resolution, checkpoint=checkpoint) name = generate_unique_name() output_video = os.path.join(VIDEO_TARGET_DIRECTORY, f"{name}.mp4") video_info = sv.VideoInfo.from_video_path(input_video) video_info.width = int(video_info.width * VIDEO_SCALE_FACTOR) video_info.height = int(video_info.height * VIDEO_SCALE_FACTOR) total = min(video_info.total_frames, video_info.fps * MAX_VIDEO_LENGTH_SECONDS) frames_generator = sv.get_video_frames_generator(input_video, end=total) with sv.VideoSink(output_video, video_info=video_info) as sink: for frame in tqdm(frames_generator, total=total): annotated_frame = detect_and_annotate( model=model, image=frame, confidence=confidence ) annotated_frame = sv.scale_image(annotated_frame, VIDEO_SCALE_FACTOR) sink.write_frame(annotated_frame) return output_video # Crear directorio temporal create_directory(directory_path=VIDEO_TARGET_DIRECTORY) # Interfaz de imágenes image_inputs = [ gr.Image(type="pil", label="Sube una imagen"), gr.Slider(0.1, 1.0, value=0.5, step=0.05, label="Umbral de confianza"), gr.Slider(320, 1400, step=8, value=728, label="Resolución de entrada"), gr.Radio(choices=["base", "large"], value="base", label="Modelo (checkpoint)") ] image_interface = gr.Interface( fn=image_processing_inference, inputs=image_inputs, outputs=gr.Image(type="pil", label="Resultado con detecciones"), title="Detección en Imágenes", description="Carga una imagen para detectar objetos con RF-DETR." ) # Interfaz de videos video_inputs = [ gr.Video(label="Sube un video"), gr.Slider(0.1, 1.0, value=0.5, step=0.05, label="Umbral de confianza"), gr.Slider(560, 1120, step=56, value=728, label="Resolución de entrada"), gr.Radio(choices=["base", "large"], value="base", label="Modelo (checkpoint)") ] video_interface = gr.Interface( fn=video_processing_inference, inputs=video_inputs, outputs=gr.Video(label="Video con detecciones", height=600), title="Detección en Videos", description="Carga un video corto para detectar objetos con RF-DETR." ) # Interfaz con pestañas demo = gr.TabbedInterface( interface_list=[image_interface, video_interface], tab_names=["Imagen", "Video"] ) if __name__ == "__main__": demo.launch(debug=False, show_error=True,mcp_server=True)