Spaces:
Sleeping
Sleeping
| import torch | |
| from torchvision.models.detection import maskrcnn_resnet50_fpn | |
| from torchvision.transforms import functional as F | |
| import numpy as np | |
| from PIL import Image | |
def load_model():
    """Build torchvision's Mask R-CNN (ResNet-50 FPN backbone) in inference mode.

    The model is constructed with ``pretrained=True`` so it carries
    pretrained detection weights.

    Returns:
        The Mask R-CNN module, already switched to evaluation mode.
    """
    # nn.Module.eval() returns the module itself, so construction and the
    # switch to evaluation mode can be chained in one expression.
    # NOTE(review): newer torchvision releases deprecate ``pretrained`` in
    # favour of ``weights=``; kept as-is to stay compatible with older versions.
    return maskrcnn_resnet50_fpn(pretrained=True).eval()
def extract_person_mask(model, image_pil, score_threshold=0.8, mask_threshold=0.5):
    """Return a binary mask for the first qualifying person detection.

    Args:
        model: Detection model (e.g. from ``load_model()``). It is called with
            a one-element list of image tensors and must return a list whose
            first item is a dict with ``labels``, ``scores`` and ``masks``.
        image_pil: Input PIL image in any mode. It is converted to RGB first,
            because the detector expects 3-channel input. RGBA, grayscale or
            palette images would otherwise yield 4 or 1 channels.
        score_threshold: A detection must score strictly above this value.
        mask_threshold: Per-pixel mask probability above which a pixel counts
            as part of the person (default 0.5, the previous hard-coded value).

    Returns:
        A ``numpy`` array of dtype ``uint8`` with values 0 or 255, taken from
        channel 0 of the selected detection's mask. Returns ``None`` when no
        person detection scores above ``score_threshold``.
    """
    person_label = 1  # label id this code treats as the "person" class

    image_tensor = F.to_tensor(image_pil.convert("RGB"))

    # Feed the input on the same device as the model's weights. Plain
    # callables without parameters fall back to CPU, matching the old behaviour.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cpu")

    with torch.no_grad():
        predictions = model([image_tensor.to(device)])[0]

    # First match in the model's output order wins. Scores are compared as
    # Python floats (via .item()) so threshold boundaries behave as before.
    for label, score, mask_probs in zip(
        predictions["labels"], predictions["scores"], predictions["masks"]
    ):
        if label.item() == person_label and score.item() > score_threshold:
            probabilities = mask_probs[0].cpu().numpy()
            return (probabilities > mask_threshold).astype(np.uint8) * 255
    return None
def apply_mask_to_image(image_pil, mask):
    """Return a new RGBA image whose alpha channel is ``mask``.

    Args:
        image_pil: Source PIL image. It is converted to RGBA, so any existing
            alpha is replaced.
        mask: 2-D array assigned into the alpha plane. It must broadcast to
            the image's (height, width), and its values are stored as uint8.
            ``extract_person_mask`` produces 0/255 masks, so pixels outside
            the person become fully transparent.

    Returns:
        A new PIL image built from the RGBA pixel array.
    """
    pixels = np.array(image_pil.convert("RGBA"))
    # Index 3 of the last axis is the alpha channel in RGBA order.
    pixels[..., 3] = mask
    return Image.fromarray(pixels)
def save_segmented_person(output_image, output_path):
    """Write ``output_image`` to ``output_path`` and print a confirmation line.

    Args:
        output_image: Object exposing ``save(path)`` (a PIL image in this file).
        output_path: Destination handed unchanged to ``output_image.save``.
    """
    output_image.save(output_path)
    confirmation = f"Segmented person saved to: {output_path}"
    print(confirmation)
def segment_person(image_pil, output_path="", *, model=None, score_threshold=0.8):
    """Cut the detected person out of ``image_pil`` as a transparent RGBA image.

    Args:
        image_pil: Input PIL image.
        output_path: If non-empty, the result is also saved to this path.
        model: Optional pre-loaded detection model. Building Mask R-CNN is
            expensive, so callers that process many images should call
            ``load_model()`` once and pass the result here. When omitted, a
            model is loaded for this call only (the original behaviour).
        score_threshold: Minimum detection score, forwarded to
            ``extract_person_mask``.

    Returns:
        The RGBA image with non-person pixels made transparent, or ``None``
        (after printing a message) when no person detection qualifies.
    """
    if model is None:
        model = load_model()

    mask = extract_person_mask(model, image_pil, score_threshold=score_threshold)
    if mask is None:
        print("No person detected with high enough confidence.")
        return None

    segmented_image = apply_mask_to_image(image_pil, mask)
    if output_path:
        save_segmented_person(segmented_image, output_path)
    return segmented_image
if __name__ == "__main__":
    # Command-line entry point. Both paths are optional and default to the
    # values this script previously hard-coded.
    import argparse

    parser = argparse.ArgumentParser(
        description="Segment a person from an image and save it as a transparent PNG."
    )
    parser.add_argument(
        "input_image_path",
        nargs="?",
        default="./person/person1.jpg",
        help="image to segment (default: ./person/person1.jpg)",
    )
    parser.add_argument(
        "output_image_path",
        nargs="?",
        default="segmented_person.png",
        help="where to write the RGBA result (default: segmented_person.png)",
    )
    args = parser.parse_args()

    # The context manager closes the source file once the RGB copy exists.
    with Image.open(args.input_image_path) as source:
        image = source.convert("RGB")
    segment_person(image, args.output_image_path)