Spaces:
Sleeping
Sleeping
| import torch | |
| from torchvision import transforms, models | |
| from torchvision.models.segmentation import deeplabv3_resnet101, DeepLabV3_ResNet101_Weights | |
| import numpy as np | |
| import cv2 | |
| import skimage.segmentation as seg | |
| from PIL import Image, ImageDraw, ImageFont | |
# Class id the rest of this file treats as "dog" (per the model's label map;
# comments below use 12 for the dog class — confirm against the VOC label list).
dog_class = 12
# Prefer GPU when available; all inference below runs on this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device {device} for segmentation")
# Load DeepLabV3 model for segmentation
print("Loading resnet101 segmentation model...")
# Pretrained DeepLabV3 with a ResNet-101 backbone; eval() switches
# dropout/batch-norm into inference mode.
seg_model = models.segmentation \
    .deeplabv3_resnet101(weights=DeepLabV3_ResNet101_Weights.DEFAULT) \
    .to(device) \
    .eval()
def refine_dog_mask(mask, image):
    """Merge, smooth, and superpixel-refine the dog regions of a class mask.

    Args:
        mask: 2-D integer class-id array produced by the segmentation model.
        image: PIL image the mask was computed from (used for SLIC superpixels).

    Returns:
        The mask as ``np.uint8`` with dog pixels (``dog_class``) extended to
        cover every superpixel that overlaps the detected dog region.
    """
    # Merge all dog segments into one binary mask.
    dog_mask = (mask == dog_class).astype(np.uint8)
    # Morphological close + dilate to connect fragmented segments.
    kernel = np.ones((15, 15), np.uint8)
    dog_mask = cv2.morphologyEx(dog_mask, cv2.MORPH_CLOSE, kernel)  # Close gaps
    dog_mask = cv2.dilate(dog_mask, kernel, iterations=2)  # Expand segmentation
    # Refine using SLIC superpixels: any superpixel overlapping the dog mask is
    # claimed entirely for the dog.  (The previous code computed the superpixel
    # labels but immediately overwrote them with dog_class at the same
    # positions, which made this refinement a no-op.)
    segments = seg.slic(np.array(image), n_segments=100, compactness=10)
    overlapping = np.unique(segments[dog_mask == 1])
    refined = np.isin(segments, overlapping)  # empty `overlapping` -> all False
    # Restore the dog class label in the refined regions.
    mask[refined] = dog_class
    # Downstream code expects a uint8 mask.
    return mask.astype(np.uint8)
def crop_dog(image, mask, target_aspect=1, padding=20):
    """Crop *image* to the dog's bounding box, aspect-adjusted and padded.

    Args:
        image: PIL image to crop.
        mask: 2-D class-id array aligned with *image*; dog pixels == dog_class.
        target_aspect: desired width/height ratio of the crop.
        padding: extra pixels added on every side (clamped to image bounds).

    Returns:
        The cropped PIL image, or the original image when no dog pixels exist.
    """
    # Get bounding box of the dog
    y_indices, x_indices = np.where(mask == dog_class)  # Dog class pixels
    if len(y_indices) == 0 or len(x_indices) == 0:
        return image  # No dog detected
    # Normalize numpy scalars to plain ints for the arithmetic below.
    x_min, x_max = int(x_indices.min()), int(x_indices.max())
    y_min, y_max = int(y_indices.min()), int(y_indices.max())
    width = x_max - x_min
    height = y_max - y_min
    # Guard degenerate (single-row) boxes: treat them as already matching the
    # target so the aspect adjustment below cannot divide by zero.
    current_aspect = width / height if height > 0 else target_aspect
    # Grow the shorter dimension symmetrically to hit the target aspect ratio.
    if current_aspect > target_aspect:
        new_height = width / target_aspect
        diff = (new_height - height) / 2
        y_min = max(0, int(y_min - diff))
        y_max = min(mask.shape[0], int(y_max + diff))
    else:
        new_width = height * target_aspect
        diff = (new_width - width) / 2
        x_min = max(0, int(x_min - diff))
        x_max = min(mask.shape[1], int(x_max + diff))
    # Apply padding, clamped to the image bounds.
    x_min = max(0, x_min - padding)
    x_max = min(mask.shape[1], x_max + padding)
    y_min = max(0, y_min - padding)
    y_max = min(mask.shape[0], y_max + padding)
    return image.crop((x_min, y_min, x_max, y_max))
def segment_image(image):
    """Run semantic segmentation on a PIL image and refine its dog regions.

    Args:
        image: PIL image in any mode (converted to RGB internally).

    Returns:
        ``(image, mask)`` where mask is a ``np.uint8`` class-id array refined
        by ``refine_dog_mask``, or ``(image, None)`` when only the background
        class was detected.
    """
    image = image.convert("RGB")
    # A single ToTensor suffices; no Compose wrapper needed.
    image_tensor = transforms.ToTensor()(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = seg_model(image_tensor)['out'][0]
    mask = output.argmax(0)  # Keep on GPU until we know segmentation succeeded
    # Bail out early when only the background class (0) is present.
    unique_classes = mask.unique()
    unique_classes = unique_classes[unique_classes != 0]  # Remove background class (0)
    if len(unique_classes) == 0:
        print('No segmentation found')
        return image, None  # Skip image if no valid segmentation found
    mask = mask.cpu().numpy()  # Move to CPU only when needed
    mask = refine_dog_mask(mask, image)
    return image, mask
def visualize_segmentation(image, mask):
    """Blend a colour overlay of *mask* onto *image* and label each segment
    with its class id (white text with a black halo)."""
    halo_px = 2                  # offset used to fake a text border
    label_height_frac = 0.25     # label size as a fraction of segment height

    # Paint every non-background segment onto the overlay and remember its
    # contour so the class id can be drawn afterwards.
    overlay = np.zeros((*mask.shape, 3), dtype=np.uint8)
    labeled_contours = []
    for class_id in np.unique(mask):
        if class_id == 0:
            continue  # Skip background
        region = mask == class_id
        if not region.any():
            continue
        found, _ = cv2.findContours(
            region.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        fill = (0, 255, 0) if class_id == dog_class else (255, 0, 0)  # Green for dog, red for others
        for contour in found:
            if cv2.contourArea(contour) > 100:  # Filter small segments
                labeled_contours.append((contour, class_id))
                cv2.drawContours(overlay, [contour], -1, fill, thickness=cv2.FILLED)

    # Blend the translucent overlay on top of the original image.
    blended = Image.blend(
        image.convert("RGBA"), Image.fromarray(overlay).convert("RGBA"), alpha=0.3
    )

    # Write each class id at the centre of its segment's bounding box.
    draw = ImageDraw.Draw(blended)
    for contour, class_id in labeled_contours:
        x, y, w, h = cv2.boundingRect(contour)
        size = max(10, int(h * label_height_frac))
        try:
            font = ImageFont.truetype("arial.ttf", size)
        except IOError:
            font = ImageFont.load_default()
        cx, cy = x + w // 2, y + h // 2
        label = str(class_id)
        # Four black offsets first (the halo), then the white core on top.
        for dx, dy in ((-halo_px, 0), (halo_px, 0), (0, -halo_px), (0, halo_px)):
            draw.text((cx + dx, cy + dy), label, fill=(0, 0, 0, 255), font=font)
        draw.text((cx, cy), label, fill=(255, 255, 255, 255), font=font)

    return blended.convert("RGB")