sanrapi's picture
Update app.py
d06bdab verified
import gradio as gr
from PIL import Image
from tqdm import tqdm
from typing import TypeVar, Tuple
import numpy as np
from rfdetr import RFDETRBase, RFDETRLarge
from rfdetr.util.coco_classes import COCO_CLASSES
import supervision as sv
from rfdetr.detr import RFDETR
import datetime
import os
import shutil
import uuid
ImageType = TypeVar("ImageType", Image.Image, np.ndarray)
COLOR = sv.ColorPalette.from_hex([
"#ffff00", "#ff9b00", "#ff8080", "#ff66b2", "#ff66ff", "#b266ff",
"#9999ff", "#3399ff", "#66ffff", "#33ff99", "#66ff66", "#99ff00"
])
MAX_VIDEO_LENGTH_SECONDS = 5
VIDEO_SCALE_FACTOR = 0.5
VIDEO_TARGET_DIRECTORY = "tmp"
def create_directory(directory_path: str) -> None:
"""Crea un directorio si no existe.
Args:
directory_path (str): Ruta del directorio a crear.
"""
if not os.path.exists(directory_path):
os.makedirs(directory_path)
def delete_directory(directory_path: str) -> None:
"""Elimina un directorio existente.
Args:
directory_path (str): Ruta del directorio a eliminar.
Raises:
FileNotFoundError: Si el directorio no existe.
PermissionError: Si no se tienen permisos para eliminarlo.
"""
if not os.path.exists(directory_path):
raise FileNotFoundError(f"El directorio '{directory_path}' no existe.")
try:
shutil.rmtree(directory_path)
except PermissionError:
raise PermissionError(f"Permiso denegado: No se puede eliminar '{directory_path}'.")
def generate_unique_name() -> str:
"""Genera un nombre 煤nico basado en la fecha y un UUID.
Returns:
str: Nombre 煤nico.
"""
current_datetime = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
unique_id = uuid.uuid4()
return f"{current_datetime}_{unique_id}"
def calculate_resolution_wh(image: ImageType) -> Tuple[int, int]:
"""Obtiene la resoluci贸n (ancho, alto) de una imagen.
Args:
image (ImageType): Imagen tipo PIL o NumPy.
Returns:
Tuple[int, int]: Resoluci贸n como (ancho, alto).
Raises:
ValueError: Si la imagen no tiene al menos dos dimensiones.
TypeError: Si el tipo de imagen no es soportado.
"""
if isinstance(image, Image.Image):
return image.size
elif isinstance(image, np.ndarray):
if image.ndim >= 2:
h, w = image.shape[:2]
return w, h
else:
raise ValueError("La imagen numpy debe tener al menos 2 dimensiones (alto, ancho).")
else:
raise TypeError("La imagen debe ser de tipo PIL.Image o numpy.ndarray.")
def load_model(resolution: int, checkpoint: str) -> RFDETR:
"""Carga un modelo RF-DETR seg煤n el checkpoint y la resoluci贸n.
Args:
resolution (int): Resoluci贸n de entrada del modelo.
checkpoint (str): Checkpoint a usar: 'base' o 'large'.
Returns:
RFDETR: Modelo cargado listo para inferencia.
Raises:
TypeError: Si el checkpoint no es v谩lido.
"""
if checkpoint == "base":
return RFDETRBase(resolution=resolution)
elif checkpoint == "large":
return RFDETRLarge(resolution=resolution)
raise TypeError("El checkpoint debe ser 'base' o 'large'.")
def detect_and_annotate(model: RFDETR, image: ImageType, confidence: float) -> ImageType:
"""Detecta objetos en una imagen y la anota con las detecciones.
Args:
model (RFDETR): Modelo de detecci贸n ya cargado.
image (ImageType): Imagen sobre la cual detectar objetos.
confidence (float): Umbral de confianza para las detecciones.
Returns:
ImageType: Imagen anotada con cajas y etiquetas.
"""
detections = model.predict(image, threshold=confidence)
resolution_wh = calculate_resolution_wh(image)
text_scale = sv.calculate_optimal_text_scale(resolution_wh=resolution_wh) - 0.4
thickness = sv.calculate_optimal_line_thickness(resolution_wh=resolution_wh) - 2
bbox_annotator = sv.BoxAnnotator(color=COLOR, thickness=thickness)
label_annotator = sv.LabelAnnotator(
color=COLOR,
text_color=sv.Color.BLACK,
text_scale=text_scale,
text_padding=1
)
labels = [
f"{COCO_CLASSES[class_id]} {conf:.2f}"
for class_id, conf in zip(detections.class_id, detections.confidence)
]
annotated_image = image.copy()
annotated_image = bbox_annotator.annotate(annotated_image, detections)
annotated_image = label_annotator.annotate(annotated_image, detections, labels)
return annotated_image
def image_processing_inference(input_image: Image.Image, confidence: float, resolution: int, checkpoint: str) -> Image.Image:
"""Funci贸n principal para inferencia sobre una imagen.
Args:
input_image (Image.Image): Imagen de entrada.
confidence (float): Umbral de confianza para detecci贸n.
resolution (int): Resoluci贸n objetivo del modelo.
checkpoint (str): Tipo de modelo ('base' o 'large').
Returns:
Image.Image: Imagen procesada con detecciones.
"""
input_image = input_image.resize((resolution, resolution))
model = load_model(resolution=resolution, checkpoint=checkpoint)
return detect_and_annotate(model=model, image=input_image, confidence=confidence)
def video_processing_inference(input_video: str, confidence: float, resolution: int, checkpoint: str, progress=gr.Progress(track_tqdm=True)) -> str:
"""Funci贸n principal para inferencia sobre un video.
Args:
input_video (str): Ruta al archivo de video.
confidence (float): Umbral de confianza.
resolution (int): Resoluci贸n del modelo.
checkpoint (str): Tipo de modelo ('base' o 'large').
progress (gr.Progress): Barra de progreso de Gradio.
Returns:
str: Ruta al video procesado con anotaciones.
"""
model = load_model(resolution=resolution, checkpoint=checkpoint)
name = generate_unique_name()
output_video = os.path.join(VIDEO_TARGET_DIRECTORY, f"{name}.mp4")
video_info = sv.VideoInfo.from_video_path(input_video)
video_info.width = int(video_info.width * VIDEO_SCALE_FACTOR)
video_info.height = int(video_info.height * VIDEO_SCALE_FACTOR)
total = min(video_info.total_frames, video_info.fps * MAX_VIDEO_LENGTH_SECONDS)
frames_generator = sv.get_video_frames_generator(input_video, end=total)
with sv.VideoSink(output_video, video_info=video_info) as sink:
for frame in tqdm(frames_generator, total=total):
annotated_frame = detect_and_annotate(
model=model,
image=frame,
confidence=confidence
)
annotated_frame = sv.scale_image(annotated_frame, VIDEO_SCALE_FACTOR)
sink.write_frame(annotated_frame)
return output_video
# Crear directorio temporal
create_directory(directory_path=VIDEO_TARGET_DIRECTORY)
# Interfaz de im谩genes
image_inputs = [
gr.Image(type="pil", label="Sube una imagen"),
gr.Slider(0.1, 1.0, value=0.5, step=0.05, label="Umbral de confianza"),
gr.Slider(320, 1400, step=8, value=728, label="Resoluci贸n de entrada"),
gr.Radio(choices=["base", "large"], value="base", label="Modelo (checkpoint)")
]
image_interface = gr.Interface(
fn=image_processing_inference,
inputs=image_inputs,
outputs=gr.Image(type="pil", label="Resultado con detecciones"),
title="Detecci贸n en Im谩genes",
description="Carga una imagen para detectar objetos con RF-DETR."
)
# Interfaz de videos
video_inputs = [
gr.Video(label="Sube un video"),
gr.Slider(0.1, 1.0, value=0.5, step=0.05, label="Umbral de confianza"),
gr.Slider(560, 1120, step=56, value=728, label="Resoluci贸n de entrada"),
gr.Radio(choices=["base", "large"], value="base", label="Modelo (checkpoint)")
]
video_interface = gr.Interface(
fn=video_processing_inference,
inputs=video_inputs,
outputs=gr.Video(label="Video con detecciones", height=600),
title="Detecci贸n en Videos",
description="Carga un video corto para detectar objetos con RF-DETR."
)
# Interfaz con pesta帽as
demo = gr.TabbedInterface(
interface_list=[image_interface, video_interface],
tab_names=["Imagen", "Video"]
)
if __name__ == "__main__":
demo.launch(debug=False, show_error=True,mcp_server=True)