File size: 5,767 Bytes
9073e25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import torch
from torchvision import transforms, models
from torchvision.models.segmentation import deeplabv3_resnet101, DeepLabV3_ResNet101_Weights
import numpy as np
import cv2
import skimage.segmentation as seg
from PIL import Image, ImageDraw, ImageFont

# Class id treated as "dog" throughout this module.
# NOTE(review): assumed to be index 12 in the label set of the DeepLabV3
# weights loaded below — verify against
# DeepLabV3_ResNet101_Weights.DEFAULT.meta["categories"].
dog_class = 12

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device {device} for segmentation")

# Load DeepLabV3 model for segmentation
print("Loading resnet101 segmentation model...")
# Pretrained default weights, moved to the chosen device and put in eval
# mode at import time so every later call runs inference-only.
seg_model = models.segmentation \
    .deeplabv3_resnet101(weights=DeepLabV3_ResNet101_Weights.DEFAULT) \
    .to(device) \
    .eval()


def refine_dog_mask(mask, image):
    """Clean up the dog regions of a semantic-segmentation mask in place.

    Collects every pixel labelled as dog, closes small gaps and dilates the
    region so fragmented detections merge into one blob, then writes the dog
    class id back into ``mask`` over the expanded region.

    Args:
        mask: 2-D integer class-id array produced by the segmentation model.
        image: source image. Unused — retained only for backward
            compatibility with existing callers (see NOTE below).

    Returns:
        The updated mask as ``np.uint8``.
    """
    # Binary mask of every pixel the model labelled as dog (uses the module
    # constant instead of the previous hard-coded 12).
    dog_mask = (mask == dog_class).astype(np.uint8)

    # Morphological clean-up: close gaps between fragments, then expand the
    # region so nearby fragments connect.
    kernel = np.ones((15, 15), np.uint8)
    dog_mask = cv2.morphologyEx(dog_mask, cv2.MORPH_CLOSE, kernel)
    dog_mask = cv2.dilate(dog_mask, kernel, iterations=2)

    # NOTE(review): the original SLIC superpixel "refinement" here was dead
    # code — its output array was immediately overwritten with ``dog_class``
    # wherever ``dog_mask == 1``, so the final mask never depended on the
    # superpixels. The behaviour-equivalent direct assignment is kept below;
    # ``image`` is no longer read as a result.
    mask[dog_mask == 1] = dog_class

    return mask.astype(np.uint8)


def crop_dog(image, mask, target_aspect=1, padding=20):
    """Crop ``image`` to a padded bounding box around the dog pixels of ``mask``.

    Args:
        image: PIL image to crop (must match ``mask``'s spatial size).
        mask: 2-D class-id array; pixels equal to ``dog_class`` are the dog.
        target_aspect: desired width/height ratio of the crop box.
        padding: extra pixels added on every side, clamped to the image.

    Returns:
        The cropped PIL image, or ``image`` unchanged when no dog pixel exists.
    """
    # Bounding box of the dog-class pixels.
    y_indices, x_indices = np.where(mask == dog_class)
    if len(y_indices) == 0 or len(x_indices) == 0:
        return image  # No dog detected

    x_min, x_max = x_indices.min(), x_indices.max()
    y_min, y_max = y_indices.min(), y_indices.max()

    # Box size; floor to 1 so a 1-pixel-thin detection cannot make the
    # aspect-ratio division below divide by zero (previously produced a
    # numpy divide warning and an inf aspect).
    width = max(int(x_max - x_min), 1)
    height = max(int(y_max - y_min), 1)
    current_aspect = width / height

    # Grow the shorter side so the box matches the requested aspect ratio.
    if current_aspect > target_aspect:
        new_height = width / target_aspect
        diff = (new_height - height) / 2
        y_min = max(0, int(y_min - diff))
        y_max = min(mask.shape[0], int(y_max + diff))
    else:
        new_width = height * target_aspect
        diff = (new_width - width) / 2
        x_min = max(0, int(x_min - diff))
        x_max = min(mask.shape[1], int(x_max + diff))

    # Pad on all sides, clamped to the image bounds.
    x_min = max(0, x_min - padding)
    x_max = min(mask.shape[1], x_max + padding)
    y_min = max(0, y_min - padding)
    y_max = min(mask.shape[0], y_max + padding)

    return image.crop((x_min, y_min, x_max, y_max))


def segment_image(image):
    """Run the DeepLabV3 model on ``image`` and return a refined class mask.

    Args:
        image: input PIL image (converted to RGB internally).

    Returns:
        ``(image, mask)`` — the RGB image and a ``np.uint8`` per-pixel
        class-id mask refined by ``refine_dog_mask``; ``(image, None)``
        when the model predicts only the background class.
    """
    image = image.convert("RGB")
    # ToTensor alone suffices here; the previous one-element Compose wrapper
    # and the never-used ``orig_size`` local have been removed.
    image_tensor = transforms.ToTensor()(image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = seg_model(image_tensor)['out'][0]
    mask = output.argmax(0)  # Per-pixel class ids; kept on device for now

    # Bail out early when only the background class (0) is present.
    unique_classes = mask.unique()
    unique_classes = unique_classes[unique_classes != 0]  # Remove background class (0)
    if len(unique_classes) == 0:
        print(f'No segmentation found')
        return image, None  # Skip image if no valid segmentation found

    mask = mask.cpu().numpy()  # Move to CPU only when needed
    mask = refine_dog_mask(mask, image)

    return image, mask


def visualize_segmentation(image, mask):
    """Overlay ``mask`` on ``image`` and label each large segment with its class id.

    Draws filled contours (green for the dog class, red for everything else)
    blended at 30% opacity, then writes each class id at its contour centre
    with a black outline for legibility.

    Args:
        image: PIL image the mask belongs to.
        mask: 2-D class-id array of the same spatial size.

    Returns:
        An RGB PIL image with the visualisation baked in.
    """
    font_border = 2
    font_size_segment_pct = 0.25  # Label height as a fraction of contour height

    # Filled colour overlay plus the contours to label afterwards.
    overlay = np.zeros((*mask.shape, 3), dtype=np.uint8)
    contours_dict = []

    for class_id in np.unique(mask):
        if class_id == 0:
            continue  # Skip background
        # Single binary mask per class (the previous np.argwhere emptiness
        # probe was redundant with this computation).
        mask_binary = (mask == class_id).astype(np.uint8)
        if not mask_binary.any():
            continue
        contours, _ = cv2.findContours(mask_binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        for contour in contours:
            if cv2.contourArea(contour) > 100:  # Filter small segments
                contours_dict.append((contour, class_id))
                color = (0, 255, 0) if class_id == dog_class else (255, 0, 0)  # Green for dog, red for others
                cv2.drawContours(overlay, [contour], -1, color, thickness=cv2.FILLED)

    # Blend the overlay onto the source image with transparency.
    overlay_img = Image.fromarray(overlay).convert("RGBA")
    image_rgba = image.convert("RGBA")
    blended = Image.blend(image_rgba, overlay_img, alpha=0.3)

    # Draw each class id at its contour centre: four black offset copies
    # form a cheap outline behind one white copy on top.
    draw = ImageDraw.Draw(blended)
    for contour, class_id in contours_dict:
        x, y, w, h = cv2.boundingRect(contour)
        font_size = max(10, int(h * font_size_segment_pct))

        try:
            font = ImageFont.truetype("arial.ttf", font_size)
        except IOError:
            font = ImageFont.load_default()  # Fallback when arial.ttf is absent

        text_x = x + w // 2
        text_y = y + h // 2
        label = str(class_id)

        for dx, dy in ((-font_border, 0), (font_border, 0), (0, -font_border), (0, font_border)):
            draw.text((text_x + dx, text_y + dy), label, fill=(0, 0, 0, 255), font=font)
        draw.text((text_x, text_y), label, fill=(255, 255, 255, 255), font=font)

    return blended.convert("RGB")