File size: 1,931 Bytes
f74ae4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.transforms import functional as F
import numpy as np
from PIL import Image

def load_model(device=None):
    """Load a COCO-pretrained Mask R-CNN (ResNet-50 FPN) ready for inference.

    Args:
        device: Optional ``torch.device`` or device string (e.g. ``"cuda"``)
            to move the model to. When ``None`` the model is left on the
            device torchvision built it on.

    Returns:
        The detector module, switched to eval mode.
    """
    try:
        # torchvision >= 0.13 deprecates ``pretrained=True`` in favour of the
        # weights API; "DEFAULT" resolves to the same COCO weights.
        model = maskrcnn_resnet50_fpn(weights="DEFAULT")
    except TypeError:
        # Older torchvision releases have no ``weights`` keyword.
        model = maskrcnn_resnet50_fpn(pretrained=True)
    if device is not None:
        model = model.to(device)
    model.eval()
    return model

def extract_person_mask(model, image_pil, score_threshold=0.8, label_id=1, mask_threshold=0.5):
    """Return a binary mask for one detection of ``label_id`` in ``image_pil``.

    ``model`` is called with a one-element list of image tensors and its first
    prediction dict is read for ``labels``, ``scores`` and ``masks``. The first
    detection in the model's output order whose label equals ``label_id`` and
    whose score is strictly greater than ``score_threshold`` is used
    (torchvision detectors emit detections sorted by descending score, so this
    is normally the most confident match).

    Args:
        model: detector, e.g. the module returned by ``load_model()``.
        image_pil: PIL image in any mode; it is converted to RGB first.
        score_threshold: exclusive lower bound on detection confidence.
        label_id: class index to look for; 1 is "person" in the COCO label
            map used by torchvision's pretrained detectors.
        mask_threshold: soft-mask value above which a pixel is foreground.

    Returns:
        A ``numpy.uint8`` array with values 0 or 255 (the selected
        ``masks[i, 0]`` plane), or ``None`` if no detection qualifies.
    """
    # The detector expects 3-channel input; RGBA/greyscale/palette images
    # would otherwise become 4- or 1-channel tensors.
    image_tensor = F.to_tensor(image_pil.convert("RGB"))

    # Place the input on the same device as the model's weights so a model
    # that was moved to a GPU does not fail with a device mismatch.
    parameters = getattr(model, "parameters", None)
    first_param = next(iter(parameters()), None) if callable(parameters) else None
    if first_param is not None:
        image_tensor = image_tensor.to(first_param.device)

    with torch.no_grad():
        predictions = model([image_tensor])[0]

    labels = predictions["labels"].tolist()
    scores = predictions["scores"].tolist()
    for i, (label, score) in enumerate(zip(labels, scores)):
        if label == label_id and score > score_threshold:
            mask = predictions["masks"][i, 0].cpu().numpy()
            return (mask > mask_threshold).astype(np.uint8) * 255

    return None

def apply_mask_to_image(image_pil, mask):
    """Return an RGBA copy of ``image_pil`` whose alpha channel is ``mask``.

    Args:
        image_pil: source PIL image; it is converted to RGBA before masking,
            and any existing alpha is replaced.
        mask: 2-D array matching the image's height and width. Its values are
            written directly as alpha (extract_person_mask produces 0 for
            background and 255 for the person).

    Returns:
        A new PIL image suitable for saving as a transparent PNG.
    """
    pixels = np.array(image_pil.convert("RGBA"))
    alpha_channel = 3
    pixels[..., alpha_channel] = mask
    return Image.fromarray(pixels)

def save_segmented_person(output_image, output_path):
    """Write ``output_image`` to ``output_path`` and print where it was saved."""
    destination = output_path
    output_image.save(destination)
    message = f"Segmented person saved to: {destination}"
    print(message)

def segment_person(image_pil, output_path="", model=None):
    """Cut a detected person out of ``image_pil`` onto a transparent background.

    Uses ``extract_person_mask`` with its default thresholds, then
    ``apply_mask_to_image`` to turn the mask into the alpha channel.

    Args:
        image_pil: input PIL image.
        output_path: when non-empty, the result is also saved there. Use a
            format that supports alpha, such as ``.png``.
        model: optional detector previously returned by ``load_model()``.
            Passing one avoids reloading the pretrained weights on every call;
            when ``None`` a fresh model is loaded (the original behaviour).

    Returns:
        The RGBA PIL image, or ``None`` when no person passes the
        confidence threshold.
    """
    if model is None:
        model = load_model()

    mask = extract_person_mask(model, image_pil)
    if mask is None:
        print("No person detected with high enough confidence.")
        return None

    segmented_image = apply_mask_to_image(image_pil, mask)
    if output_path:
        save_segmented_person(segmented_image, output_path)
    return segmented_image

# Example usage: python <script> [input_image_path] [output_image_path]
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Segment a person from an image and save it as a transparent PNG."
    )
    parser.add_argument("input_image_path", nargs="?", default="./person/person1.jpg")
    parser.add_argument("output_image_path", nargs="?", default="segmented_person.png")
    args = parser.parse_args()

    # Close the source file promptly; convert() returns an independent,
    # fully loaded copy that remains usable after the file is closed.
    with Image.open(args.input_image_path) as source:
        image = source.convert("RGB")
    segment_person(image, args.output_image_path)