import streamlit as st from PIL import Image import torch import numpy as np import matplotlib.pyplot as plt from io import BytesIO import time # Funções principais def load_model(model_name: str): """ Função para carregar o modelo escolhido. """ if model_name == "YOLOv5": return torch.hub.load('ultralytics/yolov5', 'yolov5s') # YOLOv5 light elif model_name == "Faster R-CNN": return torch.hub.load('pytorch/vision', 'fasterrcnn_resnet50_fpn', pretrained=True) elif model_name == "RetinaNet": return torch.hub.load('facebookresearch/detectron2', 'retinanet_r50_fpn', pretrained=True) else: raise ValueError("Modelo não suportado!") def detect_objects(model, image: Image, model_name: str): """ Função que executa a detecção de objetos com o modelo escolhido. """ img_tensor = torch.tensor(np.array(image)).permute(2, 0, 1).unsqueeze(0).float() with torch.no_grad(): if model_name == "YOLOv5": results = model(img_tensor) else: results = model([img_tensor])[0] return results def save_image_with_annotations(image: Image, results, model_name: str): """ Função para salvar a imagem com anotações das detecções feitas. """ fig, ax = plt.subplots(1, 1, figsize=(12, 9)) ax.imshow(np.array(image)) for *box, conf, cls in results.xywh[0]: x1, y1, x2, y2 = box ax.add_patch(plt.Rectangle((x1, y1), x2-x1, y2-y1, linewidth=2, edgecolor='r', facecolor='none')) ax.text(x1, y1, f'{results.names[int(cls)]} {conf:.2f}', color='r', fontsize=12, verticalalignment='top') output = BytesIO() plt.savefig(output, format='PNG') output.seek(0) return output def display_feedback(results): """ Exibe um feedback dinâmico com informações sobre as detecções. """ st.write(f"🔍 Objetos Detectados: {len(results.xywh[0])}") st.write(f"🔢 Classes: {results.names}") st.write(f"📊 Confiança: {results.xywh[0][:, 4]}") # Função de inicialização def initialize_app(): """ Função para inicializar a aplicação Streamlit. """ st.set_page_config(page_title="Detecção de Objetos com IA", page_icon="🤖", layout="centered") st.title("🧠 Elias Andrade - Detecção de Objetos com IA 🤖") # Seleção de modelo model_choice = st.selectbox("Escolha o modelo de detecção", ["YOLOv5", "Faster R-CNN", "RetinaNet"]) # Upload de imagem uploaded_file = st.file_uploader("📂 Carregue uma imagem para detecção", type=["jpg", "jpeg", "png"]) return model_choice, uploaded_file # Função principal para execução da aplicação def run_app(): model_choice, uploaded_file = initialize_app() # Caso o usuário carregue uma imagem if uploaded_file is not None: image = Image.open(uploaded_file) st.image(image, caption="Imagem Carregada", use_column_width=True) # Seleção de modelo model = load_model(model_choice) start_button = st.button("Iniciar Detecção 🔍") save_button = st.button("Salvar Imagem com Marcação 💾") if start_button: with st.spinner("Detectando objetos..."): results = detect_objects(model, image, model_choice) st.image(results.render()[0], caption="Imagem com Detecção", use_column_width=True) display_feedback(results) if save_button: output_image = save_image_with_annotations(image, results, model_choice) st.download_button( label="Baixar Imagem com Marcação 📥", data=output_image, file_name="imagem_com_marcacao.png", mime="image/png" ) # Função para a escolha de objetos positivos e negativos def mark_objects(results): """ Função para permitir ao usuário marcar objetos como positivos ou negativos. """ positive_objects = st.multiselect( "Escolha objetos para marcar como positivos ✅", options=results.names, default=[] ) negative_objects = st.multiselect( "Escolha objetos para marcar como negativos ❌", options=results.names, default=[] ) if positive_objects or negative_objects: st.write("📝 Objetos Marcados:") st.write(f"Positivos: {positive_objects}") st.write(f"Negativos: {negative_objects}") else: st.write("🔘 Não há objetos marcados ainda.") # Função para mostrar opções adicionais ao usuário def show_advanced_options(results): """ Função para exibir opções avançadas de manipulação e configuração do modelo. """ st.write("🔧 Opções Avançadas") # Opção de permitir que o usuário ajuste a confiança mínima de detecção confidence_threshold = st.slider( "Defina o limiar de confiança para detectar objetos", min_value=0.0, max_value=1.0, value=0.5, step=0.05 ) filtered_results = [r for r in results.xywh[0] if r[4] > confidence_threshold] st.write(f"🔎 Detecções filtradas: {len(filtered_results)}") # Exibir a imagem com detecções filtradas if len(filtered_results) > 0: fig, ax = plt.subplots(1, 1, figsize=(12, 9)) ax.imshow(np.array(image)) for *box, conf, cls in filtered_results: x1, y1, x2, y2 = box ax.add_patch(plt.Rectangle((x1, y1), x2-x1, y2-y1, linewidth=2, edgecolor='r', facecolor='none')) ax.text(x1, y1, f'{results.names[int(cls)]} {conf:.2f}', color='r', fontsize=12, verticalalignment='top') st.pyplot(fig) # Função para execução completa da aplicação if __name__ == "__main__": run_app()