File size: 3,792 Bytes
8d9d84c
 
 
 
 
ec6147f
8d9d84c
ec6147f
 
 
 
 
 
 
 
 
 
 
8d9d84c
 
 
 
ec6147f
 
 
8d9d84c
ec6147f
 
8d9d84c
 
ec6147f
 
 
 
8d9d84c
ec6147f
 
 
 
 
 
0f63cdb
 
ec6147f
0f63cdb
ec6147f
 
 
 
 
 
0f63cdb
 
ec6147f
0f63cdb
ec6147f
0f63cdb
ec6147f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f63cdb
ec6147f
 
0f63cdb
ec6147f
 
 
 
0f63cdb
ec6147f
 
8d9d84c
0f63cdb
8d9d84c
 
 
 
ec6147f
 
8d9d84c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image
import cv2
from scipy.ndimage import binary_fill_holes

CLASS_COLORS = {
    1: (0, 255, 0),       # diecast (verde)
    2: (0, 0, 255),       # large_packaging (azul)
    3: (255, 0, 0),       # packaging (vermelho)
}
CLASS_NAMES = {
    1: "diecast",
    2: "large_packaging",
    3: "packaging",
}

MODEL_PATH = "segmentation_model.h5"
model = tf.keras.models.load_model(MODEL_PATH)

def predict_image(input_image):
    # Converte a imagem de PIL para NumPy (formato BGR para OpenCV)
    original_img_np = np.array(input_image.convert('RGB'))
    original_img_cv2 = cv2.cvtColor(original_img_np, cv2.COLOR_RGB2BGR)

    # Redimensiona para o tamanho do modelo (256, 256)
    img_resized = tf.image.resize(original_img_np, (256, 256))
    img_input = np.expand_dims(img_resized, axis=0)

    # Faz a previsão do modelo
    prediction = model.predict(img_input, verbose=0)
    
    # Obtém a máscara e a confiança para cada pixel
    mask_predicted = np.argmax(prediction[0], axis=-1)
    confidences = np.max(prediction[0], axis=-1)

    # Redimensiona a máscara e a confiança para o tamanho da imagem original
    original_size = original_img_np.shape[:2]
    mask_final = cv2.resize(mask_predicted.astype(np.uint8), (original_size[1], original_size[0]), interpolation=cv2.INTER_NEAREST)
    confidences_final = cv2.resize(confidences, (original_size[1], original_size[0]), interpolation=cv2.INTER_LINEAR)
    
    # Processa a imagem para desenhar as caixas
    final_img = original_img_cv2.copy()
    confidence_threshold = 0.8  # Limiar de 80%

    for class_id in np.unique(mask_final):
        if class_id == 0:
            continue
        
        class_name = CLASS_NAMES.get(class_id, f"Classe {class_id}")
        
        binary_mask = (mask_final == class_id)
        binary_mask = binary_fill_holes(binary_mask)
        binary_mask_uint8 = binary_mask.astype(np.uint8)
        
        contours, _ = cv2.findContours(binary_mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        
        for contour in contours:
            x_min, y_min, w, h = cv2.boundingRect(contour)
            x_max, y_max = x_min + w, y_min + h
            
            region_mask = (mask_final[y_min:y_max, x_min:x_max] == class_id)
            region_confidences = confidences_final[y_min:y_max, x_min:x_max][region_mask]
            
            if region_confidences.size > 0:
                avg_confidence = np.mean(region_confidences)
            else:
                avg_confidence = 0
            
            if avg_confidence > confidence_threshold:
                label_text = f"{class_name}: {avg_confidence:.2f}%"
                color_tuple = CLASS_COLORS.get(class_id, (255, 255, 255))
                
                # Desenha a caixa
                cv2.rectangle(final_img, (x_min, y_min), (x_max, y_max), color_tuple, 2)
                
                # Desenha o fundo do texto
                (text_width, text_height), baseline = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
                cv2.rectangle(final_img, (x_min, y_min - text_height - 10), (x_min + text_width, y_min), color_tuple, -1)
                cv2.putText(final_img, label_text, (x_min, y_min - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 2)
    
    # Converte de volta para PIL e retorna
    final_img_rgb = cv2.cvtColor(final_img, cv2.COLOR_BGR2RGB)
    return Image.fromarray(final_img_rgb)

# Define a interface Gradio
gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs="image",
    title="Ferramenta de Segmentação e Detecção de Objetos",
    description="Carregue uma imagem e o modelo irá detetar objetos com caixas e confiança."
).launch()