"""Segment the most confident person in an image with a pretrained Mask R-CNN.

The result is saved as a transparent PNG whose alpha channel is the person mask.
"""

import numpy as np
import torch
from PIL import Image
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.transforms import functional as F

# Index of the "person" category in the COCO label space used by torchvision's
# Mask R-CNN checkpoints (index 0 is the background class).
PERSON_LABEL = 1


# Load the pre-trained Mask R-CNN model
def load_model():
    """Return a COCO-pretrained Mask R-CNN (ResNet-50 FPN) in eval mode.

    On torchvision >= 0.13 the ``weights=`` API is used, because ``pretrained=``
    is deprecated there. Older releases fall back to ``pretrained=True``.
    Both select the same default COCO checkpoint.
    """
    try:
        from torchvision.models.detection import MaskRCNN_ResNet50_FPN_Weights
    except ImportError:  # torchvision < 0.13 has no weights enums
        model = maskrcnn_resnet50_fpn(pretrained=True)
    else:
        model = maskrcnn_resnet50_fpn(
            weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT
        )
    model.eval()
    return model


# Get the mask for the person class
def extract_person_mask(model, image_pil, score_threshold=0.8, mask_threshold=0.5):
    """Return a binary person mask for ``image_pil``, or ``None``.

    Args:
        model: A torchvision Mask R-CNN in eval mode (e.g. from ``load_model``).
        image_pil: Input PIL image. It is converted to RGB first, because the
            model expects 3-channel input. RGBA or grayscale images would
            otherwise yield 4- or 1-channel tensors.
        score_threshold: A detection must score strictly above this value.
        mask_threshold: Per-pixel mask probability strictly above which a
            pixel is considered part of the person.

    Returns:
        A ``np.uint8`` array with values 0 or 255. It is built from the first
        detection, in the model's output order, that is a person scoring above
        ``score_threshold``. Returns ``None`` if no detection qualifies.
    """
    image_tensor = F.to_tensor(image_pil.convert("RGB"))
    # Feed the input on the model's own device so a GPU-resident model works.
    device = next(model.parameters()).device
    with torch.no_grad():
        predictions = model([image_tensor.to(device)])[0]

    labels = predictions["labels"]
    scores = predictions["scores"]
    for i, label in enumerate(labels):
        if label.item() == PERSON_LABEL and scores[i].item() > score_threshold:
            # masks are soft probabilities; [i, 0] drops the singleton channel.
            soft_mask = predictions["masks"][i, 0].cpu().numpy()
            return (soft_mask > mask_threshold).astype(np.uint8) * 255
    return None


# Apply the mask to the image and convert to transparent PNG
def apply_mask_to_image(image_pil, mask):
    """Return an RGBA copy of ``image_pil`` whose alpha channel is ``mask``.

    Raises:
        ValueError: if ``mask`` is not shaped (height, width) of the image.
            Without this check, e.g. a 1-D mask of length ``width`` would
            silently broadcast across every row.
    """
    image_np = np.array(image_pil.convert("RGBA"))
    mask = np.asarray(mask)
    if mask.shape != image_np.shape[:2]:
        raise ValueError(
            f"mask shape {mask.shape} does not match image size {image_np.shape[:2]}"
        )
    image_np[:, :, 3] = mask
    return Image.fromarray(image_np)


# Save the image
def save_segmented_person(output_image, output_path):
    """Save ``output_image`` to ``output_path`` and print where it went."""
    output_image.save(output_path)
    print(f"Segmented person saved to: {output_path}")


# Main function to run everything
def segment_person(image_pil, output_path="", *, model=None, score_threshold=0.8):
    """Cut the most confident person out of ``image_pil``.

    Args:
        image_pil: Input PIL image.
        output_path: If non-empty, the result is also saved to this path.
        model: Optional preloaded model. Pass one when segmenting many images,
            to avoid reloading the weights on every call. If omitted,
            ``load_model()`` is called.
        score_threshold: Minimum detection score (exclusive) for a person.

    Returns:
        The RGBA image with the background made transparent, or ``None`` if
        no person was detected with sufficient confidence.
    """
    if model is None:
        model = load_model()
    mask = extract_person_mask(model, image_pil, score_threshold=score_threshold)
    if mask is None:
        print("No person detected with high enough confidence.")
        return None
    segmented_image = apply_mask_to_image(image_pil, mask)
    if output_path:
        save_segmented_person(segmented_image, output_path)
    return segmented_image


# Example usage
if __name__ == "__main__":
    input_image_path = "./person/person1.jpg"
    output_image_path = "segmented_person.png"
    # Close the underlying file handle once the pixels are loaded.
    with Image.open(input_image_path) as img:
        image = img.convert("RGB")
    segment_person(image, output_image_path)