import torch
from torchvision import transforms, models
from torchvision.models.segmentation import deeplabv3_resnet101, DeepLabV3_ResNet101_Weights
import numpy as np
import cv2
import skimage.segmentation as seg
from PIL import Image, ImageDraw, ImageFont

# PASCAL VOC class index that DeepLabV3 assigns to "dog".
dog_class = 12

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device {device} for segmentation")

# Load DeepLabV3 model for segmentation (pretrained weights, inference mode).
print("Loading resnet101 segmentation model...")
seg_model = models.segmentation \
    .deeplabv3_resnet101(weights=DeepLabV3_ResNet101_Weights.DEFAULT) \
    .to(device) \
    .eval()


def refine_dog_mask(mask, image):
    """Clean up the dog region of a segmentation mask.

    Merges all dog-labelled pixels, then closes gaps and dilates so
    fragmented detections become one connected region.

    Args:
        mask: 2-D integer class-label array from the segmentation model.
        image: the source PIL image (kept for interface compatibility;
            it is no longer read — see NOTE below).

    Returns:
        The mask with the (expanded) dog region set to ``dog_class``,
        as ``np.uint8``.
    """
    # Merge all dog segments together into a single binary mask.
    dog_mask = np.zeros_like(mask, dtype=np.uint8)
    dog_mask[mask == dog_class] = 1

    # Morphological close then dilate to connect fragmented segments
    # and slightly expand the detected region.
    kernel = np.ones((15, 15), np.uint8)
    dog_mask = cv2.morphologyEx(dog_mask, cv2.MORPH_CLOSE, kernel)
    dog_mask = cv2.dilate(dog_mask, kernel, iterations=2)

    # NOTE(review): the original SLIC superpixel "refinement" was a no-op:
    # its output was overwritten with dog_class before ever being used, so
    # the expensive seg.slic() call changed nothing. It has been removed.
    # If superpixel snapping is actually wanted, it must be reimplemented
    # to select whole superpixels that overlap the dog mask.

    # Write the (possibly expanded) dog region back into the label mask.
    mask[dog_mask == 1] = dog_class
    return mask.astype(np.uint8)


def crop_dog(image, mask, target_aspect=1, padding=20):
    """Crop `image` to the dog's bounding box at a target aspect ratio.

    Args:
        image: PIL image.
        mask: class-label mask aligned with `image` (dog pixels == dog_class).
        target_aspect: desired width/height ratio of the crop.
        padding: extra pixels added on every side (clamped to the image).

    Returns:
        The cropped PIL image, or the original image when no dog pixels
        are present in the mask.
    """
    # Locate all dog pixels; bail out early if there are none.
    y_indices, x_indices = np.where(mask == dog_class)
    if len(y_indices) == 0 or len(x_indices) == 0:
        return image  # No dog detected

    x_min, x_max = x_indices.min(), x_indices.max()
    y_min, y_max = y_indices.min(), y_indices.max()

    width = x_max - x_min
    height = y_max - y_min
    # Guard against a degenerate single-row detection (height == 0),
    # which would otherwise raise ZeroDivisionError. An infinite aspect
    # forces the "grow the height" branch, which is the right correction.
    current_aspect = width / height if height > 0 else float('inf')

    # Grow the shorter dimension so the box matches target_aspect.
    if current_aspect > target_aspect:
        new_height = width / target_aspect
        diff = (new_height - height) / 2
        y_min = max(0, int(y_min - diff))
        y_max = min(mask.shape[0], int(y_max + diff))
    else:
        new_width = height * target_aspect
        diff = (new_width - width) / 2
        x_min = max(0, int(x_min - diff))
        x_max = min(mask.shape[1], int(x_max + diff))

    # Apply padding, clamped to the image bounds.
    x_min = max(0, x_min - padding)
    x_max = min(mask.shape[1], x_max + padding)
    y_min = max(0, y_min - padding)
    y_max = min(mask.shape[0], y_max + padding)

    return image.crop((x_min, y_min, x_max, y_max))


def segment_image(image):
    """Run DeepLabV3 on `image` and return (image, refined mask).

    Returns (image, None) when the model predicts only background.
    """
    image = image.convert("RGB")

    # The pretrained DeepLabV3 weights expect ImageNet-normalized input;
    # feeding raw [0, 1] tensors (as the original code did) degrades the
    # segmentation badly. Mean/std values are the standard ImageNet stats
    # documented for torchvision pretrained models.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    image_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = seg_model(image_tensor)['out'][0]
    mask = output.argmax(0)  # Keep on GPU for the class check below.

    # Dynamically determine whether any non-background object was found.
    unique_classes = mask.unique()
    unique_classes = unique_classes[unique_classes != 0]  # Drop background (0)
    if len(unique_classes) == 0:
        print('No segmentation found')
        return image, None  # Skip image if no valid segmentation found

    mask = mask.cpu().numpy()  # Move to CPU only when needed
    mask = refine_dog_mask(mask, image)
    return image, mask


def visualize_segmentation(image, mask):
    """Overlay the segmentation on `image` and label each region.

    Dog regions are drawn in green, every other class in red; each
    sufficiently large contour gets its class id drawn at its center
    with a simple 4-direction black outline for readability.
    """
    font_border = 2                 # Outline offset in pixels.
    font_size_segment_pct = 0.25    # Label size relative to segment height.

    # Build a color overlay: one filled contour set per class.
    overlay = np.zeros((*mask.shape, 3), dtype=np.uint8)
    unique_classes = np.unique(mask)
    contours_dict = []
    for class_id in unique_classes:
        if class_id == 0:
            continue  # Skip background
        mask_binary = (mask == class_id).astype(np.uint8)
        if not np.any(mask_binary):
            continue
        contours, _ = cv2.findContours(mask_binary, cv2.RETR_EXTERNAL,
                                       cv2.CHAIN_APPROX_SIMPLE)
        for contour in contours:
            if cv2.contourArea(contour) > 100:  # Filter small segments
                contours_dict.append((contour, class_id))
                # Green for dog, red for everything else.
                color = (0, 255, 0) if class_id == dog_class else (255, 0, 0)
                cv2.drawContours(overlay, [contour], -1, color,
                                 thickness=cv2.FILLED)

    # Blend the overlay onto the image with transparency.
    overlay_img = Image.fromarray(overlay).convert("RGBA")
    image_rgba = image.convert("RGBA")
    blended = Image.blend(image_rgba, overlay_img, alpha=0.3)

    # Draw the class id inside each retained contour.
    draw = ImageDraw.Draw(blended)
    for contour, class_id in contours_dict:
        x, y, w, h = cv2.boundingRect(contour)
        font_size = max(10, int(h * font_size_segment_pct))
        try:
            font = ImageFont.truetype("arial.ttf", font_size)
        except IOError:
            font = ImageFont.load_default()
        text_x = x + w // 2
        text_y = y + h // 2
        # Black text offset in four directions forms a crude outline,
        # then the white label is drawn on top.
        draw.text((text_x - font_border, text_y), str(class_id),
                  fill=(0, 0, 0, 255), font=font)
        draw.text((text_x + font_border, text_y), str(class_id),
                  fill=(0, 0, 0, 255), font=font)
        draw.text((text_x, text_y - font_border), str(class_id),
                  fill=(0, 0, 0, 255), font=font)
        draw.text((text_x, text_y + font_border), str(class_id),
                  fill=(0, 0, 0, 255), font=font)
        draw.text((text_x, text_y), str(class_id),
                  fill=(255, 255, 255, 255), font=font)

    return blended.convert("RGB")