# is-it-max / segmentation.py
# visualise-segmentation (#1) — commit 9073e25 (verified), by paddeh
import torch
from torchvision import transforms, models
from torchvision.models.segmentation import deeplabv3_resnet101, DeepLabV3_ResNet101_Weights
import numpy as np
import cv2
import skimage.segmentation as seg
from PIL import Image, ImageDraw, ImageFont
# Class index for "dog" in the model's label map (PASCAL VOC ordering,
# per the "Dog class" comment used throughout this file).
dog_class = 12

# Run on GPU when available; every tensor below is moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device {device} for segmentation")

# Load DeepLabV3 model for segmentation
# Loaded once at import time with pretrained default weights and shared by
# segment_image(); kept in eval mode since this file only runs inference.
print("Loading resnet101 segmentation model...")
seg_model = models.segmentation \
    .deeplabv3_resnet101(weights=DeepLabV3_ResNet101_Weights.DEFAULT) \
    .to(device) \
    .eval()
def refine_dog_mask(mask, image):
    """Clean up and expand the dog regions of a segmentation mask.

    Collects every dog-labelled pixel into one binary mask, closes gaps and
    dilates it morphologically, then snaps the result to SLIC superpixel
    boundaries so the final dog region follows image edges.

    Args:
        mask: 2-D integer class-label array from the segmentation model
            (mutated in place).
        image: PIL image the mask was produced from (same spatial size).

    Returns:
        ``mask`` as ``np.uint8`` with the refined dog region set to
        ``dog_class``.
    """
    # Binary mask of every dog pixel. Replaces the original per-class loop,
    # which only ever matched a hard-coded 12 instead of dog_class.
    dog_mask = (mask == dog_class).astype(np.uint8)
    # Morphological cleanup: close gaps between fragments, then grow the region.
    kernel = np.ones((15, 15), np.uint8)
    dog_mask = cv2.morphologyEx(dog_mask, cv2.MORPH_CLOSE, kernel)  # Close gaps
    dog_mask = cv2.dilate(dog_mask, kernel, iterations=2)  # Expand segmentation
    # Superpixel refinement: mark every superpixel that overlaps the dilated
    # dog mask. NOTE(review): the original computed the SLIC segments but
    # overwrote them with dog_class before thresholding, so the refinement was
    # a no-op; this applies the refinement the original comments described.
    segments = seg.slic(np.array(image), n_segments=100, compactness=10)
    dog_superpixels = np.unique(segments[dog_mask == 1])
    mask[np.isin(segments, dog_superpixels)] = dog_class
    # Convert mask to np.uint8 if necessary
    return mask.astype(np.uint8)
def crop_dog(image, mask, target_aspect=1, padding=20):
    """Crop ``image`` to a padded bounding box around the dog pixels in ``mask``.

    Args:
        image: PIL image (anything exposing ``crop((left, top, right, bottom))``).
        mask: 2-D array of class labels; pixels equal to ``dog_class`` are dog.
        target_aspect: desired width/height ratio of the crop box.
        padding: extra pixels added on every side, clamped to the image bounds.

    Returns:
        The cropped image, or ``image`` unchanged when no dog pixel exists.
    """
    # Get bounding box of the dog
    y_indices, x_indices = np.where(mask == dog_class)  # Dog class pixels
    if len(y_indices) == 0 or len(x_indices) == 0:
        return image  # No dog detected
    x_min, x_max = x_indices.min(), x_indices.max()
    y_min, y_max = y_indices.min(), y_indices.max()
    # Calculate aspect ratio of the detected box.
    width = x_max - x_min
    height = y_max - y_min
    # Guard degenerate detections (single-row/column dog) so the aspect-ratio
    # division below cannot raise ZeroDivisionError.
    if width == 0:
        width = 1
    if height == 0:
        height = 1
    current_aspect = width / height
    # Grow the shorter side so the box matches target_aspect.
    if current_aspect > target_aspect:
        new_height = width / target_aspect
        diff = (new_height - height) / 2
        y_min = max(0, int(y_min - diff))
        y_max = min(mask.shape[0], int(y_max + diff))
    else:
        new_width = height * target_aspect
        diff = (new_width - width) / 2
        x_min = max(0, int(x_min - diff))
        x_max = min(mask.shape[1], int(x_max + diff))
    # Apply padding, clamped to the image bounds.
    x_min = max(0, x_min - padding)
    x_max = min(mask.shape[1], x_max + padding)
    y_min = max(0, y_min - padding)
    y_max = min(mask.shape[0], y_max + padding)
    # NOTE(review): PIL crop boxes are right/bottom-exclusive, so using the
    # inclusive pixel indices x_max/y_max trims one row/column; the default
    # padding hides this. Kept as-is to preserve existing output.
    return image.crop((x_min, y_min, x_max, y_max))
def segment_image(image):
    """Run DeepLabV3 semantic segmentation on ``image``.

    Args:
        image: PIL image; converted to RGB before inference.

    Returns:
        Tuple ``(image, mask)`` where ``mask`` is a ``np.uint8`` class-label
        array refined by :func:`refine_dog_mask`, or ``(image, None)`` when
        the model predicts nothing but background.
    """
    image = image.convert("RGB")
    # Single-op Compose collapsed; the unused `orig_size` local was removed.
    image_tensor = transforms.ToTensor()(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = seg_model(image_tensor)['out'][0]
    mask = output.argmax(0)  # Per-pixel class ids; keep on GPU for the check below
    # Bail out early when only background (class 0) was predicted.
    unique_classes = mask.unique()
    unique_classes = unique_classes[unique_classes != 0]  # Remove background class (0)
    if len(unique_classes) == 0:
        print(f'No segmentation found')
        return image, None  # Skip image if no valid segmentation found
    mask = mask.cpu().numpy()  # Move to CPU only when needed
    mask = refine_dog_mask(mask, image)
    return image, mask
def visualize_segmentation(image, mask):
    """Overlay segment outlines and class ids on ``image``.

    Dog segments (class ``dog_class``) are filled green and every other class
    red; each sufficiently large segment gets its class id drawn at the centre
    of its bounding box, outlined in black for legibility.

    Returns a new RGB PIL image; ``image`` and ``mask`` are not modified.
    """
    outline_px = 2      # offset used to fake a text outline
    label_scale = 0.25  # font size as a fraction of segment height

    # Paint filled contours for every non-background class into an RGB canvas.
    canvas = np.zeros((*mask.shape, 3), dtype=np.uint8)
    labelled_contours = []
    for cls in np.unique(mask):
        if cls == 0:
            continue  # background
        binary = (mask == cls).astype(np.uint8)
        if not binary.any():
            continue
        found, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        fill = (0, 255, 0) if cls == dog_class else (255, 0, 0)  # green dog, red others
        for contour in found:
            if cv2.contourArea(contour) > 100:  # ignore tiny fragments
                labelled_contours.append((contour, cls))
                cv2.drawContours(canvas, [contour], -1, fill, thickness=cv2.FILLED)

    # Blend the coloured canvas over the photo at 30% opacity.
    tinted = Image.blend(image.convert("RGBA"),
                         Image.fromarray(canvas).convert("RGBA"),
                         alpha=0.3)

    # Stamp each segment with its class id: four black copies offset around
    # the centre form an outline, then a white copy goes on top.
    pen = ImageDraw.Draw(tinted)
    for contour, cls in labelled_contours:
        x, y, w, h = cv2.boundingRect(contour)
        size = max(10, int(h * label_scale))
        try:
            font = ImageFont.truetype("arial.ttf", size)
        except IOError:
            font = ImageFont.load_default()
        cx, cy = x + w // 2, y + h // 2
        label = str(cls)
        for dx, dy in ((-outline_px, 0), (outline_px, 0), (0, -outline_px), (0, outline_px)):
            pen.text((cx + dx, cy + dy), label, fill=(0, 0, 0, 255), font=font)
        pen.text((cx, cy), label, fill=(255, 255, 255, 255), font=font)
    return tinted.convert("RGB")