# Author: Anestrom — "Update app.py (#2)" (commit 036274a)
import torch
import requests
from io import BytesIO
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import gradio as gr
from transformers import AutoProcessor, OmDetTurboForObjectDetection
# Hugging Face Hub id shared by the processor and the detection model.
MODEL_ID = "omlab/omdet-turbo-swin-tiny-hf"

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Iniciando no dispositivo: {device.upper()}")

# Loaded once at import time; the request handler reuses both objects.
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = OmDetTurboForObjectDetection.from_pretrained(MODEL_ID).to(device)
def plot_results(image, results):
    """Draw the detected boxes and their captions on top of *image*.

    Args:
        image: Image handed to ``ax.imshow`` (a PIL image in this app).
        results: One post-processed detection dict with ``"scores"`` and
            ``"boxes"`` entries, plus the class names under
            ``"text_labels"`` (or ``"classes"`` as a fallback). Each box is
            unpacked as ``(xmin, ymin, xmax, ymax)``.

    Returns:
        A ``matplotlib.figure.Figure`` containing the annotated image.
    """
    # Build the figure without pyplot: pyplot keeps a global reference to
    # every figure it creates, so a long-running server that never calls
    # plt.close() accumulates one figure (and its memory) per request.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(8, 6))
    ax = fig.subplots()
    ax.imshow(image)
    ax.axis("off")

    labels = results.get("text_labels", results.get("classes", []))
    for score, class_name, box in zip(results["scores"], labels, results["boxes"]):
        xmin, ymin, xmax, ymax = box.tolist()
        ax.add_patch(
            patches.Rectangle(
                (xmin, ymin),
                xmax - xmin,
                ymax - ymin,
                linewidth=2,
                edgecolor="red",
                facecolor="none",
            )
        )
        ax.text(
            xmin,
            ymin - 5,  # caption sits just above the box's top edge
            # float() accepts both Python numbers and 0-dim tensors.
            f"{class_name}: {float(score):.2f}",
            color="white",
            fontsize=10,
            weight="bold",
            backgroundcolor="red",
        )
    return fig
def _baixar_imagem(url, timeout=15):
    """Download *url* and decode the payload as an RGB PIL image.

    Raises ``requests.HTTPError`` for non-2xx responses (rather than letting
    PIL fail on an HTML error page) and ``requests.Timeout`` when the server
    does not answer within *timeout* seconds.
    """
    # NOTE(review): the URL comes straight from the user, so the server will
    # fetch any address it can reach (including internal ones). Consider an
    # allow-list if this app is exposed publicly.
    resposta = requests.get(url, timeout=timeout)
    resposta.raise_for_status()
    return Image.open(BytesIO(resposta.content)).convert("RGB")


def _extrair_classes(classes_texto):
    """Split a comma-separated string into unique, non-empty class names."""
    nomes = (c.strip() for c in (classes_texto or "").split(","))
    # dict.fromkeys removes duplicates while keeping the user's order.
    return list(dict.fromkeys(n for n in nomes if n))


def _formatar_resultados(results):
    """Render one line per detection: ``name (score) -> [box coords]``."""
    labels = results.get("text_labels", results.get("classes", []))
    texto = "".join(
        f"{nome} ({round(score.item(), 2)}) -> "
        f"{[round(c, 1) for c in box.tolist()]}\n"
        for score, nome, box in zip(results["scores"], labels, results["boxes"])
    )
    return texto or "Nenhum objeto detectado."


def detectar_objetos(url, classes_texto):
    """Detect the requested object classes in the image found at *url*.

    Args:
        url: Address of the image to analyse (Gradio textbox value).
        classes_texto: Comma-separated class names, e.g. ``"cat, dog"``.

    Returns:
        ``(figure, text)`` on success: the annotated plot and one line per
        detection. On invalid input or any failure, returns
        ``(None, "Erro: ...")`` so the Gradio UI shows a message instead of
        crashing.
    """
    url = (url or "").strip()
    if not url:
        return None, "Erro: informe a URL da imagem."
    classes = _extrair_classes(classes_texto)
    if not classes:
        return None, "Erro: informe ao menos uma classe."

    try:
        image = _baixar_imagem(url)
        task = "Detect {}.".format(", ".join(classes))
        inputs = processor(
            images=[image],
            text=[classes],
            task=[task],
            return_tensors="pt",
        ).to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        results = processor.post_process_grounded_object_detection(
            outputs,
            text_labels=[classes],
            # PIL's size is (width, height); reversed to (height, width).
            target_sizes=[image.size[::-1]],
            threshold=0.2,
            nms_threshold=0.3,
        )[0]
        saida = _formatar_resultados(results)
        fig = plot_results(image, results)
        return fig, saida
    except Exception as e:
        # UI boundary: report network, decoding and model failures as text.
        return None, f"Erro: {str(e)}"
# Inputs: image URL and a comma-separated class list, passed in this order
# to detectar_objetos; its (figure, text) result fills the two outputs.
input_components = [
    gr.Textbox(label="URL da imagem"),
    gr.Textbox(label="Classes (separadas por vírgula)", value="cat, dog"),
]
output_components = [
    gr.Plot(label="Imagem com detecção"),
    gr.Textbox(label="Resultados"),
]

interface = gr.Interface(
    fn=detectar_objetos,
    inputs=input_components,
    outputs=output_components,
    title="Detecção de Objetos por URL",
    description="Cole uma URL de imagem e informe os objetos que deseja detectar.",
)

# Start the Gradio web app.
interface.launch()