File size: 2,972 Bytes
e1f504d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from io import BytesIO

import gradio as gr
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import requests
import torch
from matplotlib.figure import Figure
from PIL import Image
from transformers import AutoProcessor, OmDetTurboForObjectDetection

# Hugging Face checkpoint shared by the processor and the detection model.
MODEL_ID = "omlab/omdet-turbo-swin-tiny-hf"

# Run on the GPU when PyTorch reports CUDA is available, otherwise on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Iniciando no dispositivo: {device.upper()}")

# Load the pre/post-processing pipeline and the model from the same
# checkpoint, then move the model's weights onto the chosen device.
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = OmDetTurboForObjectDetection.from_pretrained(MODEL_ID)
model = model.to(device)

def plot_results(image, results):
    """Render ``image`` with one red box and score caption per detection.

    Args:
        image: Image handed to ``Axes.imshow`` (an RGB PIL image when called
            from ``detectar_objetos``).
        results: A single post-processed detection result providing
            ``"scores"``, ``"boxes"`` (``[xmin, ymin, xmax, ymax]`` per
            detection) and the per-detection labels under ``"text_labels"``,
            falling back to ``"classes"`` when that key is absent.

    Returns:
        The ``matplotlib.figure.Figure`` holding the annotated image.
    """
    # Build the Figure directly instead of through pyplot: pyplot keeps every
    # figure it creates registered until plt.close() is called, so a
    # long-running Gradio server would otherwise keep one figure alive per
    # request and grow memory without bound.
    fig = Figure(figsize=(8, 6))
    ax = fig.subplots()
    ax.imshow(image)
    ax.axis("off")

    labels = results.get("text_labels", results.get("classes", []))

    for score, class_name, box in zip(results["scores"], labels, results["boxes"]):
        xmin, ymin, xmax, ymax = box.tolist()
        ax.add_patch(
            patches.Rectangle(
                (xmin, ymin),
                xmax - xmin,
                ymax - ymin,
                linewidth=2,
                edgecolor="red",
                facecolor="none",
            )
        )
        # Caption sits just above the top edge of the box.
        ax.text(
            xmin,
            ymin - 5,
            f"{class_name}: {float(score):.2f}",
            color="white",
            fontsize=10,
            weight="bold",
            backgroundcolor="red",
        )

    return fig

def _parse_class_names(classes_texto):
    """Return the stripped, non-empty names from a comma-separated string."""
    # Dropping blank entries keeps inputs like "cat, , dog" or "cat, dog,"
    # from sending an empty prompt to the model.
    return [nome.strip() for nome in (classes_texto or "").split(",") if nome.strip()]


def _load_image_from_url(url, timeout):
    """Download ``url`` and decode the response body as an RGB PIL image."""
    response = requests.get(url, timeout=timeout)
    # Report HTTP errors (404, 403, ...) directly instead of letting PIL fail
    # later with a generic "cannot identify image file" on an HTML error page.
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")


def detectar_objetos(
    url, classes_texto, threshold=0.2, nms_threshold=0.3, timeout=15
):
    """Detect the requested object classes in the image found at ``url``.

    Used as the Gradio callback: never raises, always returns a
    ``(figure, text)`` pair for the two output components.

    Args:
        url: Address of the image to download.
        classes_texto: Comma-separated class names, e.g. ``"cat, dog"``.
            Blank entries are ignored.
        threshold: Score threshold forwarded to
            ``processor.post_process_grounded_object_detection``.
        nms_threshold: NMS threshold forwarded to the same post-processing.
        timeout: Seconds to wait for the image download before giving up.

    Returns:
        On success, the annotated figure from ``plot_results`` and one line
        per detection formatted as ``"<class> (<score>) -> [xmin, ymin,
        xmax, ymax]"``. On any failure, ``(None, "Erro: <message>")``.
    """
    url = (url or "").strip()
    if not url:
        return None, "Erro: informe a URL da imagem."

    classes = _parse_class_names(classes_texto)
    if not classes:
        return None, "Erro: informe ao menos uma classe."

    # Top-level UI boundary: any failure (network, decoding, model) is shown
    # to the user as text instead of crashing the Gradio worker.
    try:
        image = _load_image_from_url(url, timeout)

        task = "Detect {}.".format(", ".join(classes))
        inputs = processor(
            images=[image],
            text=[classes],
            task=[task],
            return_tensors="pt",
        ).to(device)

        with torch.no_grad():
            outputs = model(**inputs)

        # One image in the batch, so take the first (and only) result.
        results = processor.post_process_grounded_object_detection(
            outputs,
            text_labels=[classes],
            target_sizes=[image.size[::-1]],  # PIL gives (w, h); (h, w) is passed here
            threshold=threshold,
            nms_threshold=nms_threshold,
        )[0]

        labels = results.get("text_labels", results.get("classes", []))
        saida = "".join(
            f"{class_name} ({round(score.item(), 2)}) -> "
            f"{[round(b, 1) for b in box.tolist()]}\n"
            for score, class_name, box in zip(
                results["scores"], labels, results["boxes"]
            )
        )

        fig = plot_results(image, results)
        return fig, saida

    except Exception as e:
        return None, f"Erro: {str(e)}"

# Inputs: the image URL and a comma-separated list of classes to look for.
entradas = [
    gr.Textbox(label="URL da imagem"),
    gr.Textbox(label="Classes (separadas por vírgula)", value="cat, dog"),
]
# Outputs: one per element of detectar_objetos's (figure, text) return value.
saidas = [
    gr.Plot(label="Imagem com detecção"),
    gr.Textbox(label="Resultados"),
]

interface = gr.Interface(
    fn=detectar_objetos,
    inputs=entradas,
    outputs=saidas,
    title="Detecção de Objetos por URL",
    description="Cole uma URL de imagem e informe os objetos que deseja detectar.",
)

# Start the web server as soon as the module is executed.
interface.launch()