Spaces:
Sleeping
Sleeping
File size: 2,972 Bytes
e1f504d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 | import torch
import requests
from io import BytesIO
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import gradio as gr
from transformers import AutoProcessor, OmDetTurboForObjectDetection
# Hugging Face checkpoint shared by the processor and the detection model.
_MODEL_ID = "omlab/omdet-turbo-swin-tiny-hf"

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Iniciando no dispositivo: {device.upper()}")

# Load the pre/post-processor first, then the model, and move the model to `device`.
processor = AutoProcessor.from_pretrained(_MODEL_ID)
model = OmDetTurboForObjectDetection.from_pretrained(_MODEL_ID).to(device)
def plot_results(image, results):
    """Draw detection boxes and "class: score" captions on top of *image*.

    Args:
        image: Image to display (passed to ``Axes.imshow``; in this app a PIL
            RGB image).
        results: Dict with "scores", "boxes" and "text_labels" (falling back
            to "classes") entries. Each box is unpacked as
            (xmin, ymin, xmax, ymax).

    Returns:
        A matplotlib ``Figure`` holding the annotated image. It is not
        registered with pyplot, so it is garbage-collected once the caller
        drops it.
    """
    # Use the object-oriented API instead of pyplot: pyplot keeps every
    # figure alive in a global registry until plt.close() is called, so a
    # long-running server would leak one figure per request. pyplot's global
    # state is also not thread-safe, and the UI framework may run handlers
    # in worker threads.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(8, 6))
    ax = fig.subplots()  # single Axes, same as plt.subplots(1)
    ax.imshow(image)
    ax.axis("off")

    labels = results.get("text_labels", results.get("classes", []))
    for score, class_name, box in zip(results["scores"], labels, results["boxes"]):
        xmin, ymin, xmax, ymax = box.tolist()
        rect = patches.Rectangle(
            (xmin, ymin),
            xmax - xmin,
            ymax - ymin,
            linewidth=2,
            edgecolor='red',
            facecolor='none',
        )
        ax.add_patch(rect)
        # float() accepts both Python floats and 0-d tensors.
        label = f"{class_name}: {float(score):.2f}"
        # Place the caption just above the box's top edge.
        ax.text(
            xmin,
            ymin - 5,
            label,
            color='white',
            fontsize=10,
            weight='bold',
            backgroundcolor="red",
        )
    return fig
# Seconds to wait for the image server; without a timeout requests.get can
# block a worker indefinitely.
_TIMEOUT_DOWNLOAD = 30


def _parse_classes(classes_texto):
    """Split a comma-separated string into stripped, non-empty class names."""
    return [c.strip() for c in (classes_texto or "").split(",") if c.strip()]


def _baixar_imagem(url, timeout=_TIMEOUT_DOWNLOAD):
    """Download *url* and decode it as an RGB PIL image.

    Raises ``requests.HTTPError`` for non-2xx responses, rather than handing
    an HTML error page to PIL (which fails with a cryptic decode error).
    """
    resposta = requests.get(url, timeout=timeout)
    resposta.raise_for_status()
    return Image.open(BytesIO(resposta.content)).convert("RGB")


def _formatar_resultados(results):
    """Render one 'class (score) -> [xmin, ymin, xmax, ymax]' line per detection."""
    labels = results.get("text_labels", results.get("classes", []))
    linhas = []
    for score, class_name, box in zip(results["scores"], labels, results["boxes"]):
        box_rounded = [round(b, 1) for b in box.tolist()]
        linhas.append(f"{class_name} ({round(float(score), 2)}) -> {box_rounded}\n")
    return "".join(linhas)


def detectar_objetos(url, classes_texto):
    """Detect the requested object classes in the image at *url*.

    Args:
        url: HTTP(S) URL of the image; surrounding whitespace is ignored.
        classes_texto: Comma-separated class names, e.g. "cat, dog". Empty
            entries (from trailing or repeated commas) are discarded.

    Returns:
        ``(figure, text)`` on success, where *figure* is the annotated plot
        and *text* lists one detection per line. On any failure returns
        ``(None, "Erro: ...")``, so the UI shows the message instead of
        crashing.
    """
    url = (url or "").strip()
    if not url:
        return None, "Erro: informe a URL da imagem."
    classes = _parse_classes(classes_texto)
    if not classes:
        return None, "Erro: informe ao menos uma classe."

    try:
        image = _baixar_imagem(url)
        # The model is prompted with a natural-language task plus the label list.
        task = "Detect {}.".format(", ".join(classes))
        inputs = processor(
            images=[image],
            text=[classes],
            task=[task],
            return_tensors="pt",
        ).to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        results = processor.post_process_grounded_object_detection(
            outputs,
            text_labels=[classes],
            # PIL's size is (width, height); the processor expects (height, width).
            target_sizes=[image.size[::-1]],
            threshold=0.2,
            nms_threshold=0.3,
        )[0]  # batch of one image
        saida = _formatar_resultados(results)
        fig = plot_results(image, results)
        return fig, saida
    except Exception as e:
        # UI boundary: show the message to the user, but keep the traceback
        # in the server log instead of silently swallowing it.
        import traceback

        traceback.print_exc()
        return None, f"Erro: {str(e)}"
# Input widgets, in the order their values are passed to
# detectar_objetos(url, classes_texto).
_campos_entrada = [
    gr.Textbox(label="URL da imagem"),
    gr.Textbox(label="Classes (separadas por vírgula)", value="cat, dog"),
]

# Output widgets receiving the (figure, text) pair returned by detectar_objetos.
_campos_saida = [
    gr.Plot(label="Imagem com detecção"),
    gr.Textbox(label="Resultados"),
]

interface = gr.Interface(
    fn=detectar_objetos,
    inputs=_campos_entrada,
    outputs=_campos_saida,
    title="Detecção de Objetos por URL",
    description="Cole uma URL de imagem e informe os objetos que deseja detectar.",
)
interface.launch()