"""Gradio app: open-vocabulary object detection on an image fetched from a URL.

Uses OmDet-Turbo (``omlab/omdet-turbo-swin-tiny-hf``) via Hugging Face
``transformers``. The user supplies an image URL and a comma-separated list of
class names; the app returns an annotated figure and a plain-text summary of
the detections.
"""

import logging
from io import BytesIO

import gradio as gr
import matplotlib.patches as patches
import matplotlib.pyplot as plt  # noqa: F401  (kept; drawing below uses Figure directly)
import requests
import torch
from matplotlib.figure import Figure
from PIL import Image
from transformers import AutoProcessor, OmDetTurboForObjectDetection

logger = logging.getLogger(__name__)

MODEL_ID = "omlab/omdet-turbo-swin-tiny-hf"
# Minimum score for a detection to be kept (passed as `threshold`).
SCORE_THRESHOLD = 0.2
# Threshold passed to non-maximum suppression during post-processing.
NMS_THRESHOLD = 0.3
# Seconds to wait on the image host before giving up, so a slow or dead
# server cannot block a Gradio worker indefinitely.
REQUEST_TIMEOUT = 30

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Iniciando no dispositivo: {device.upper()}")

processor = AutoProcessor.from_pretrained(MODEL_ID)
model = OmDetTurboForObjectDetection.from_pretrained(MODEL_ID).to(device)


def _result_labels(results):
    """Return the per-box class names from one post-processed result dict.

    The key differs across ``transformers`` versions (``"text_labels"`` vs.
    ``"classes"``), so both are tried; an empty list is returned if neither
    is present.
    """
    return results.get("text_labels", results.get("classes", []))


def _parse_classes(classes_texto):
    """Split a comma-separated string into class names.

    Blank entries (e.g. from a trailing comma) and duplicates are dropped;
    the first-occurrence order is preserved.
    """
    names = (name.strip() for name in (classes_texto or "").split(","))
    return list(dict.fromkeys(name for name in names if name))


def _load_image(url):
    """Download *url* and decode the body as an RGB PIL image.

    Raises:
        requests.RequestException: on connection errors, timeouts, or an
            HTTP error status (4xx/5xx) from the server.
        PIL.UnidentifiedImageError: if the body is not a decodable image.
    """
    response = requests.get((url or "").strip(), timeout=REQUEST_TIMEOUT)
    # Fail with a clear HTTP error instead of letting PIL choke on an
    # HTML error page.
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")


def _format_results(results):
    """Render detections as ``name (score) -> [xmin, ymin, xmax, ymax]`` lines."""
    lines = [
        f"{class_name} ({round(score.item(), 2)}) -> "
        f"{[round(coord, 1) for coord in box.tolist()]}"
        for score, class_name, box in zip(
            results["scores"], _result_labels(results), results["boxes"]
        )
    ]
    if not lines:
        return "Nenhum objeto detectado."
    return "\n".join(lines) + "\n"


def plot_results(image, results):
    """Draw each detection's box and ``label: score`` tag over *image*.

    Args:
        image: The PIL image the detections refer to.
        results: One post-processed detection dict with ``"scores"``,
            ``"boxes"`` (unpacked as ``xmin, ymin, xmax, ymax``) and class
            names (see ``_result_labels``).

    Returns:
        A matplotlib ``Figure``. It is created with the object-oriented API
        rather than ``pyplot``. Figures are then not registered in pyplot's
        global state, so they are not leaked across requests, and drawing
        does not depend on pyplot from Gradio's worker threads.
    """
    fig = Figure(figsize=(8, 6))
    ax = fig.subplots()
    ax.imshow(image)
    ax.axis("off")

    for score, class_name, box in zip(
        results["scores"], _result_labels(results), results["boxes"]
    ):
        xmin, ymin, xmax, ymax = box.tolist()
        ax.add_patch(
            patches.Rectangle(
                (xmin, ymin),
                xmax - xmin,
                ymax - ymin,
                linewidth=2,
                edgecolor="red",
                facecolor="none",
            )
        )
        ax.text(
            xmin,
            ymin - 5,  # place the tag just above the box's top edge
            f"{class_name}: {float(score):.2f}",
            color="white",
            fontsize=10,
            weight="bold",
            backgroundcolor="red",
        )
    return fig


def detectar_objetos(url, classes_texto):
    """Detect the requested classes in the image at *url*.

    Args:
        url: HTTP(S) URL of the image.
        classes_texto: Comma-separated class names, e.g. ``"cat, dog"``.

    Returns:
        ``(figure, text)`` on success. On failure it returns ``(None, message)``
        with the message prefixed by ``"Erro: "``. This function is the Gradio
        callback boundary, so errors are reported to the UI rather than
        raised.
    """
    classes = _parse_classes(classes_texto)
    if not classes:
        return None, "Erro: informe ao menos uma classe."

    try:
        image = _load_image(url)
        task = "Detect {}.".format(", ".join(classes))

        inputs = processor(
            images=[image],
            text=[classes],
            task=[task],
            return_tensors="pt",
        ).to(device)

        with torch.no_grad():
            outputs = model(**inputs)

        results = processor.post_process_grounded_object_detection(
            outputs,
            text_labels=[classes],
            # PIL's size is (width, height); post-processing wants (height, width).
            target_sizes=[image.size[::-1]],
            threshold=SCORE_THRESHOLD,
            nms_threshold=NMS_THRESHOLD,
        )[0]

        return plot_results(image, results), _format_results(results)
    except Exception as e:  # UI boundary: report to the user, but keep the traceback in logs
        logger.exception("Object detection failed for URL %r", url)
        return None, f"Erro: {str(e)}"


interface = gr.Interface(
    fn=detectar_objetos,
    inputs=[
        gr.Textbox(label="URL da imagem"),
        gr.Textbox(label="Classes (separadas por vírgula)", value="cat, dog"),
    ],
    outputs=[
        gr.Plot(label="Imagem com detecção"),
        gr.Textbox(label="Resultados"),
    ],
    title="Detecção de Objetos por URL",
    description="Cole uma URL de imagem e informe os objetos que deseja detectar.",
)

if __name__ == "__main__":
    interface.launch()