Nicolasg2c
Add author attribution comment at the top of app.py
57c1ead
Raw
History Blame Contribute Delete
3.86 kB
#Presentado por Nicolas Gerardo Gutierrez Carre帽o
from transformers import DetrImageProcessor, DetrForObjectDetection
import torch
from PIL import Image ## Para manejar im谩genes
import gradio as gr
import numpy as np
import io
# Cargar el procesador y el modelo
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
# Funci贸n para procesar la imagen
def detect_objects(image):
# Preprocesamiento de la imagen, en este paso se ajusta el tama帽o y se normaliza.
inputs = processor(images=image, return_tensors="pt")
# Detectar objetos, en este paso se obtiene la predicci贸n del modelo.
with torch.no_grad():
outputs = model(**inputs)
# Filtrar resultados, en este paso se obtienen las cajas delimitadoras, etiquetas y puntuaciones.
target_sizes = torch.tensor([image.size[::-1]]) # (alto, ancho)
# Se realiza el post-procesamiento para obtener los resultados finales, aplicando un umbral de puntuaci贸n de 0.9.
results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
# Crear una lista de los resultados con nombre y puntuaci贸n, se obtienen las etiquetas de los objetos detectados.
labels = results["labels"]
# Se obtiene el puntaje de los objetos detectados.
scores = results["scores"]
# Se obtiene las cajas delimitadoras de los objetos detectados.
boxes = results["boxes"]
# Mostrar los objetos detectados
detected_objects = []
for score, label, box in zip(scores, labels, boxes):
#Se a帽aden los objetos detectados a la lista, incluyendo su etiqueta, nombre, puntuaci贸n y caja delimitadora.
detected_objects.append(f"Objeto: {label}: {model.config.id2label[label.item()]}, Score: {score:.2f}, Box: {box.tolist()}")
# # Visualizar los resultados de la detecci贸n, este codigo es para mostrar la imagen dibujando la caja delimitadora.
# plt.figure(figsize=(10, 10))
# plt.imshow(image)
# ax = plt.gca()
# for score, label, box in zip(scores, labels, boxes):
# x_min, y_min, x_max, y_max = box
# rect = Rectangle((x_min, y_min), x_max - x_min, y_max - y_min,
# fill=False, color='red', linewidth=2)
# ax.add_patch(rect)
# ax.text(x_min, y_min, f"{model.config.id2label[label.item()]} | Total score: {score:.2f}",
# fontsize=12, color='red', verticalalignment='top')
# plt.axis('off')
# plt.tight_layout()
# fig = plt.gcf()
# buf = io.BytesIO()
# fig.savefig(buf, format='png', bbox_inches='tight')
# buf.seek(0)
# plt.close()
#Retornar la imagen con las detecciones
# return Image.open(buf)
#En caso de que no se detecten objetos con un score mayor a 0.9, se informa al usuario.
if detected_objects.__len__() == 0:
results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0)[0]
max_score = results["scores"].max()
max_label = results["labels"][results["scores"].argmax()]
return f"No se detectaron objetos con un score mayor a 0.9. El objeto m谩s cercano tiene un score de {max_score:.2f} y su label es {model.config.id2label[max_label.item()]}."
#Se retorna el arreglo con los strings de los objetos detectados.
return "\n".join(detected_objects)
# Crear la interfaz Gradio
def create_interface():
interface = gr.Interface(
fn=detect_objects,
inputs=gr.Image(type="pil"),
outputs=gr.Textbox(),
live=True,
title="Detecci贸n de Objetos con Transformers",
description="Sube una imagen y descubre qu茅 objetos se pueden detectar."
)
interface.launch()
if __name__ == "__main__":
create_interface()