# 3D_Image_Composer / image_segmentation_mask_rcnn.py
# Uploaded by gexu13 ("Upload 16 files"), commit f74ae4b (verified).
import torch
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.transforms import functional as F
import numpy as np
from PIL import Image
# Load the pre-trained Mask R-CNN model
def load_model(device=None):
    """Build a COCO-pretrained Mask R-CNN (ResNet-50 FPN) ready for inference.

    Args:
        device: Optional torch device or device string (e.g. ``"cuda"``) to
            move the model to. ``None`` leaves the model on the device it was
            constructed on.

    Returns:
        The Mask R-CNN model, switched to ``eval()`` mode.
    """
    try:
        # torchvision >= 0.13 replaced ``pretrained=True`` with the ``weights``
        # enum; the legacy keyword only survives there as a deprecated alias.
        from torchvision.models.detection import MaskRCNN_ResNet50_FPN_Weights
    except ImportError:
        # Older torchvision has no weights enum; the legacy keyword is the only option.
        model = maskrcnn_resnet50_fpn(pretrained=True)
    else:
        # DEFAULT is the same COCO_V1 checkpoint that ``pretrained=True`` maps to.
        model = maskrcnn_resnet50_fpn(weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT)
    if device is not None:
        model = model.to(device)
    model.eval()
    return model
# Get the mask for the person class
def extract_person_mask(model, image_pil, score_threshold=0.8, mask_threshold=0.5):
    """Return a binary mask for the first confidently detected person.

    Args:
        model: A torchvision Mask R-CNN in eval mode (e.g. from ``load_model``).
            The input tensor is moved to the device of its parameters.
        image_pil: Input PIL image. It is converted to RGB first, so RGBA,
            palette and grayscale images are accepted.
        score_threshold: A person detection is used only if its score is
            strictly greater than this value.
        mask_threshold: Soft-mask values strictly greater than this become
            foreground.

    Returns:
        A ``uint8`` NumPy array of shape (H, W) holding 255 on the person and
        0 elsewhere, taken from the first qualifying detection in the model's
        output order. Returns ``None`` if no person detection passes
        ``score_threshold``.
    """
    person_label = 1  # COCO category id for "person" in torchvision's label map
    # Mask R-CNN expects 3-channel input. F.to_tensor on an RGBA image yields
    # 4 channels (breaking the model's RGB normalisation), and on a palette
    # image it yields raw palette indices, so normalise the mode first.
    image_tensor = F.to_tensor(image_pil.convert("RGB"))
    # Feed the input on the model's own device so CUDA models work too.
    device = next(model.parameters()).device
    with torch.no_grad():
        predictions = model([image_tensor.to(device)])[0]
    # One host transfer each instead of a device sync per detection via .item().
    labels = predictions["labels"].tolist()
    scores = predictions["scores"].tolist()
    for idx, (label, score) in enumerate(zip(labels, scores)):
        if label == person_label and score > score_threshold:
            # masks are indexed (detection, channel, H, W); take this detection's plane.
            soft_mask = predictions["masks"][idx, 0].cpu().numpy()
            return (soft_mask > mask_threshold).astype(np.uint8) * 255
    return None
# Apply the mask to the image and convert to transparent PNG
def apply_mask_to_image(image_pil, mask):
    """Return an RGBA copy of ``image_pil`` whose alpha channel is ``mask``.

    Args:
        image_pil: Source PIL image in any mode convertible to RGBA.
        mask: Array assignable to the image's (height, width) plane, such as
            the 0/255 ``uint8`` mask from ``extract_person_mask``. Its values
            are written into the alpha channel, where 0 means fully transparent.

    Returns:
        A new RGBA ``PIL.Image`` (an in-memory image, not an encoded PNG);
        ``image_pil`` itself is not modified.
    """
    pixels = np.array(image_pil.convert("RGBA"))
    alpha_index = 3  # position of A in RGBA
    pixels[..., alpha_index] = mask
    return Image.fromarray(pixels)
# Save the image
def save_segmented_person(output_image, output_path):
    """Write ``output_image`` to ``output_path`` and print where it went.

    No explicit format is passed to ``save``, so PIL picks one from the path.
    The confirmation is printed only after the save call returns.
    """
    message = f"Segmented person saved to: {output_path}"
    output_image.save(output_path)
    print(message)
# Main function to run everything
def segment_person(image_pil, output_path="", *, model=None, score_threshold=0.8):
    """Cut the detected person out of an image as a transparent RGBA image.

    Args:
        image_pil: Input PIL image.
        output_path: If non-empty, the result is also saved to this path. Use
            a format with alpha support, such as ``.png``.
        model: Optional pre-loaded Mask R-CNN. Pass one when segmenting several
            images so the network is not rebuilt for every call. When ``None``,
            a fresh model is created with ``load_model()``.
        score_threshold: Minimum person detection score, forwarded to
            ``extract_person_mask``.

    Returns:
        The RGBA ``PIL.Image`` with everything except the person made
        transparent, or ``None`` (after printing a notice) if no person is
        detected with enough confidence.
    """
    if model is None:
        model = load_model()
    mask = extract_person_mask(model, image_pil, score_threshold=score_threshold)
    if mask is None:
        print("No person detected with high enough confidence.")
        return None
    segmented_image = apply_mask_to_image(image_pil, mask)
    if output_path:
        save_segmented_person(segmented_image, output_path)
    return segmented_image
# Example usage
if __name__ == "__main__":
    import sys

    # Optional CLI overrides: python <this script> [INPUT_IMAGE] [OUTPUT_IMAGE].
    # Without arguments the original example paths are used.
    input_image_path = sys.argv[1] if len(sys.argv) > 1 else "./person/person1.jpg"
    output_image_path = sys.argv[2] if len(sys.argv) > 2 else "segmented_person.png"
    # The context manager releases the source file once the RGB copy exists.
    with Image.open(input_image_path) as source_image:
        image = source_image.convert("RGB")
    segment_person(image, output_image_path)