Spaces:
Sleeping
Sleeping
File size: 5,800 Bytes
5177584 f9739d9 5177584 f9739d9 5177584 f9739d9 5177584 f9739d9 5177584 f9739d9 5177584 f9739d9 5177584 f9739d9 5177584 f9739d9 5177584 f9739d9 5177584 f9739d9 5177584 f9739d9 5177584 f9739d9 5177584 f9739d9 5177584 f9739d9 5177584 f9739d9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
import streamlit as st
from PIL import Image
import torch
import numpy as np
import matplotlib.pyplot as plt
from io import BytesIO
import time
# Funções principais
def load_model(model_name: str):
"""
Função para carregar o modelo escolhido.
"""
if model_name == "YOLOv5":
return torch.hub.load('ultralytics/yolov5', 'yolov5s') # YOLOv5 light
elif model_name == "Faster R-CNN":
return torch.hub.load('pytorch/vision', 'fasterrcnn_resnet50_fpn', pretrained=True)
elif model_name == "RetinaNet":
return torch.hub.load('facebookresearch/detectron2', 'retinanet_r50_fpn', pretrained=True)
else:
raise ValueError("Modelo não suportado!")
def detect_objects(model, image: Image, model_name: str):
"""
Função que executa a detecção de objetos com o modelo escolhido.
"""
img_tensor = torch.tensor(np.array(image)).permute(2, 0, 1).unsqueeze(0).float()
with torch.no_grad():
if model_name == "YOLOv5":
results = model(img_tensor)
else:
results = model([img_tensor])[0]
return results
def save_image_with_annotations(image: Image, results, model_name: str):
"""
Função para salvar a imagem com anotações das detecções feitas.
"""
fig, ax = plt.subplots(1, 1, figsize=(12, 9))
ax.imshow(np.array(image))
for *box, conf, cls in results.xywh[0]:
x1, y1, x2, y2 = box
ax.add_patch(plt.Rectangle((x1, y1), x2-x1, y2-y1, linewidth=2, edgecolor='r', facecolor='none'))
ax.text(x1, y1, f'{results.names[int(cls)]} {conf:.2f}', color='r', fontsize=12, verticalalignment='top')
output = BytesIO()
plt.savefig(output, format='PNG')
output.seek(0)
return output
def display_feedback(results):
"""
Exibe um feedback dinâmico com informações sobre as detecções.
"""
st.write(f"🔍 Objetos Detectados: {len(results.xywh[0])}")
st.write(f"🔢 Classes: {results.names}")
st.write(f"📊 Confiança: {results.xywh[0][:, 4]}")
# Função de inicialização
def initialize_app():
"""
Função para inicializar a aplicação Streamlit.
"""
st.set_page_config(page_title="Detecção de Objetos com IA", page_icon="🤖", layout="centered")
st.title("🧠 Elias Andrade - Detecção de Objetos com IA 🤖")
# Seleção de modelo
model_choice = st.selectbox("Escolha o modelo de detecção", ["YOLOv5", "Faster R-CNN", "RetinaNet"])
# Upload de imagem
uploaded_file = st.file_uploader("📂 Carregue uma imagem para detecção", type=["jpg", "jpeg", "png"])
return model_choice, uploaded_file
# Função principal para execução da aplicação
def run_app():
model_choice, uploaded_file = initialize_app()
# Caso o usuário carregue uma imagem
if uploaded_file is not None:
image = Image.open(uploaded_file)
st.image(image, caption="Imagem Carregada", use_column_width=True)
# Seleção de modelo
model = load_model(model_choice)
start_button = st.button("Iniciar Detecção 🔍")
save_button = st.button("Salvar Imagem com Marcação 💾")
if start_button:
with st.spinner("Detectando objetos..."):
results = detect_objects(model, image, model_choice)
st.image(results.render()[0], caption="Imagem com Detecção", use_column_width=True)
display_feedback(results)
if save_button:
output_image = save_image_with_annotations(image, results, model_choice)
st.download_button(
label="Baixar Imagem com Marcação 📥",
data=output_image,
file_name="imagem_com_marcacao.png",
mime="image/png"
)
# Função para a escolha de objetos positivos e negativos
def mark_objects(results):
"""
Função para permitir ao usuário marcar objetos como positivos ou negativos.
"""
positive_objects = st.multiselect(
"Escolha objetos para marcar como positivos ✅",
options=results.names, default=[]
)
negative_objects = st.multiselect(
"Escolha objetos para marcar como negativos ❌",
options=results.names, default=[]
)
if positive_objects or negative_objects:
st.write("📝 Objetos Marcados:")
st.write(f"Positivos: {positive_objects}")
st.write(f"Negativos: {negative_objects}")
else:
st.write("🔘 Não há objetos marcados ainda.")
# Função para mostrar opções adicionais ao usuário
def show_advanced_options(results):
"""
Função para exibir opções avançadas de manipulação e configuração do modelo.
"""
st.write("🔧 Opções Avançadas")
# Opção de permitir que o usuário ajuste a confiança mínima de detecção
confidence_threshold = st.slider(
"Defina o limiar de confiança para detectar objetos",
min_value=0.0, max_value=1.0, value=0.5, step=0.05
)
filtered_results = [r for r in results.xywh[0] if r[4] > confidence_threshold]
st.write(f"🔎 Detecções filtradas: {len(filtered_results)}")
# Exibir a imagem com detecções filtradas
if len(filtered_results) > 0:
fig, ax = plt.subplots(1, 1, figsize=(12, 9))
ax.imshow(np.array(image))
for *box, conf, cls in filtered_results:
x1, y1, x2, y2 = box
ax.add_patch(plt.Rectangle((x1, y1), x2-x1, y2-y1, linewidth=2, edgecolor='r', facecolor='none'))
ax.text(x1, y1, f'{results.names[int(cls)]} {conf:.2f}', color='r', fontsize=12, verticalalignment='top')
st.pyplot(fig)
# Função para execução completa da aplicação
if __name__ == "__main__":
run_app()
|