File size: 5,800 Bytes
5177584
 
 
 
 
 
f9739d9
5177584
f9739d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5177584
f9739d9
 
 
 
 
5177584
 
 
f9739d9
 
 
 
5177584
 
f9739d9
5177584
 
 
 
 
 
 
 
 
 
 
f9739d9
 
 
 
 
 
 
5177584
f9739d9
 
 
 
 
 
 
5177584
f9739d9
 
5177584
f9739d9
 
5177584
f9739d9
5177584
f9739d9
 
 
5177584
f9739d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5177584
 
 
 
 
 
 
 
f9739d9
 
5177584
 
 
f9739d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5177584
f9739d9
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import streamlit as st
from PIL import Image
import torch
import numpy as np
import matplotlib.pyplot as plt
from io import BytesIO
import time

# Funções principais
def load_model(model_name: str):
    """
    Função para carregar o modelo escolhido.
    """
    if model_name == "YOLOv5":
        return torch.hub.load('ultralytics/yolov5', 'yolov5s')  # YOLOv5 light
    elif model_name == "Faster R-CNN":
        return torch.hub.load('pytorch/vision', 'fasterrcnn_resnet50_fpn', pretrained=True)
    elif model_name == "RetinaNet":
        return torch.hub.load('facebookresearch/detectron2', 'retinanet_r50_fpn', pretrained=True)
    else:
        raise ValueError("Modelo não suportado!")

def detect_objects(model, image: Image, model_name: str):
    """
    Função que executa a detecção de objetos com o modelo escolhido.
    """
    img_tensor = torch.tensor(np.array(image)).permute(2, 0, 1).unsqueeze(0).float()
    
    with torch.no_grad():
        if model_name == "YOLOv5":
            results = model(img_tensor)
        else:
            results = model([img_tensor])[0]
    
    return results

def save_image_with_annotations(image: Image, results, model_name: str):
    """
    Função para salvar a imagem com anotações das detecções feitas.
    """
    fig, ax = plt.subplots(1, 1, figsize=(12, 9))
    ax.imshow(np.array(image))

    for *box, conf, cls in results.xywh[0]:
        x1, y1, x2, y2 = box
        ax.add_patch(plt.Rectangle((x1, y1), x2-x1, y2-y1, linewidth=2, edgecolor='r', facecolor='none'))
        ax.text(x1, y1, f'{results.names[int(cls)]} {conf:.2f}', color='r', fontsize=12, verticalalignment='top')

    output = BytesIO()
    plt.savefig(output, format='PNG')
    output.seek(0)
    
    return output

def display_feedback(results):
    """
    Exibe um feedback dinâmico com informações sobre as detecções.
    """
    st.write(f"🔍 Objetos Detectados: {len(results.xywh[0])}")
    st.write(f"🔢 Classes: {results.names}")
    st.write(f"📊 Confiança: {results.xywh[0][:, 4]}")

# Função de inicialização
def initialize_app():
    """
    Função para inicializar a aplicação Streamlit.
    """
    st.set_page_config(page_title="Detecção de Objetos com IA", page_icon="🤖", layout="centered")
    st.title("🧠 Elias Andrade - Detecção de Objetos com IA 🤖")

    # Seleção de modelo
    model_choice = st.selectbox("Escolha o modelo de detecção", ["YOLOv5", "Faster R-CNN", "RetinaNet"])

    # Upload de imagem
    uploaded_file = st.file_uploader("📂 Carregue uma imagem para detecção", type=["jpg", "jpeg", "png"])

    return model_choice, uploaded_file

# Função principal para execução da aplicação
def run_app():
    model_choice, uploaded_file = initialize_app()

    # Caso o usuário carregue uma imagem
    if uploaded_file is not None:
        image = Image.open(uploaded_file)
        st.image(image, caption="Imagem Carregada", use_column_width=True)

        # Seleção de modelo
        model = load_model(model_choice)

        start_button = st.button("Iniciar Detecção 🔍")
        save_button = st.button("Salvar Imagem com Marcação 💾")

        if start_button:
            with st.spinner("Detectando objetos..."):
                results = detect_objects(model, image, model_choice)
            st.image(results.render()[0], caption="Imagem com Detecção", use_column_width=True)
            display_feedback(results)

        if save_button:
            output_image = save_image_with_annotations(image, results, model_choice)
            st.download_button(
                label="Baixar Imagem com Marcação 📥",
                data=output_image,
                file_name="imagem_com_marcacao.png",
                mime="image/png"
            )

# Função para a escolha de objetos positivos e negativos
def mark_objects(results):
    """
    Função para permitir ao usuário marcar objetos como positivos ou negativos.
    """
    positive_objects = st.multiselect(
        "Escolha objetos para marcar como positivos ✅",
        options=results.names, default=[]
    )
    negative_objects = st.multiselect(
        "Escolha objetos para marcar como negativos ❌",
        options=results.names, default=[]
    )

    if positive_objects or negative_objects:
        st.write("📝 Objetos Marcados:")
        st.write(f"Positivos: {positive_objects}")
        st.write(f"Negativos: {negative_objects}")
    else:
        st.write("🔘 Não há objetos marcados ainda.")

# Função para mostrar opções adicionais ao usuário
def show_advanced_options(results):
    """
    Função para exibir opções avançadas de manipulação e configuração do modelo.
    """
    st.write("🔧 Opções Avançadas")
    
    # Opção de permitir que o usuário ajuste a confiança mínima de detecção
    confidence_threshold = st.slider(
        "Defina o limiar de confiança para detectar objetos",
        min_value=0.0, max_value=1.0, value=0.5, step=0.05
    )
    
    filtered_results = [r for r in results.xywh[0] if r[4] > confidence_threshold]
    st.write(f"🔎 Detecções filtradas: {len(filtered_results)}")

    # Exibir a imagem com detecções filtradas
    if len(filtered_results) > 0:
        fig, ax = plt.subplots(1, 1, figsize=(12, 9))
        ax.imshow(np.array(image))
        
        for *box, conf, cls in filtered_results:
            x1, y1, x2, y2 = box
            ax.add_patch(plt.Rectangle((x1, y1), x2-x1, y2-y1, linewidth=2, edgecolor='r', facecolor='none'))
            ax.text(x1, y1, f'{results.names[int(cls)]} {conf:.2f}', color='r', fontsize=12, verticalalignment='top')
        
        st.pyplot(fig)

# Função para execução completa da aplicação
if __name__ == "__main__":
    run_app()