# Source: Hugging Face Space file (revision f74ae4b, 1,931 bytes).
import torch
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.transforms import functional as F
import numpy as np
from PIL import Image
def load_model():
    """Load the COCO-pretrained Mask R-CNN (ResNet-50 FPN) in inference mode.

    Newer torchvision releases select pretrained weights through the
    ``weights=`` argument and deprecate ``pretrained=True`` (which only warns
    and maps to the same default weights). This function uses ``weights=``
    when the installed torchvision accepts it, and otherwise falls back to
    the legacy ``pretrained=True`` flag.

    Returns:
        The model after ``model.eval()``, ready for inference.
    """
    import inspect  # local: only needed for the one-off API probe below

    # Probe the signature instead of calling and catching TypeError: on old
    # torchvision an unknown keyword would only fail after the backbone
    # weights had already been fetched.
    if "weights" in inspect.signature(maskrcnn_resnet50_fpn).parameters:
        model = maskrcnn_resnet50_fpn(weights="DEFAULT")
    else:
        model = maskrcnn_resnet50_fpn(pretrained=True)
    model.eval()
    return model
def extract_person_mask(model, image_pil, score_threshold=0.8, *,
                        person_label=1, mask_threshold=0.5):
    """Return a binary mask for the first qualifying person detection.

    Detections are scanned in the order the model returns them. For
    torchvision's Mask R-CNN that order is expected to be by descending
    score, so the first hit should be the most confident person.

    Args:
        model: Mask R-CNN model in eval mode, e.g. from ``load_model()``.
        image_pil: PIL image to segment. It is converted to RGB first, so
            RGBA / grayscale / palette images are accepted as well.
        score_threshold: a detection is used only if its score is strictly
            greater than this value.
        person_label: class id treated as "person" (default 1, the value
            the original code hard-coded).
        mask_threshold: cut-off applied to the soft mask values; entries
            strictly above it count as foreground (default 0.5).

    Returns:
        A ``np.uint8`` 2-D array with 255 for foreground and 0 for
        background, or ``None`` when no detection qualifies.
    """
    # The model expects 3-channel input; RGBA/L/P images would otherwise
    # yield 4- or 1-channel tensors and fail inside the model's transform.
    rgb_image = image_pil.convert("RGB")
    # Send the input to wherever the model lives so a GPU model works too.
    device = next(model.parameters()).device
    image_tensor = F.to_tensor(rgb_image).to(device)

    with torch.no_grad():
        predictions = model([image_tensor])[0]

    labels = predictions["labels"].tolist()
    scores = predictions["scores"].tolist()
    for i, (label, score) in enumerate(zip(labels, scores)):
        if label == person_label and score > score_threshold:
            soft_mask = predictions["masks"][i, 0].cpu().numpy()
            return (soft_mask > mask_threshold).astype(np.uint8) * 255
    return None
def apply_mask_to_image(image_pil, mask):
    """Return an RGBA copy of *image_pil* whose alpha channel is *mask*.

    Args:
        image_pil: source PIL image. It is converted to RGBA and copied into
            a NumPy array, so the caller's image object is not modified.
        mask: 2-D array with the same height and width as the image, such
            as the uint8 0/255 mask from ``extract_person_mask``. It is
            written verbatim into the alpha channel.

    Returns:
        A new RGBA PIL image, suitable for saving as a transparent PNG.
    """
    pixels = np.array(image_pil.convert("RGBA"))
    pixels[..., 3] = mask  # channel index 3 is alpha in RGBA
    return Image.fromarray(pixels)
def save_segmented_person(output_image, output_path):
    """Write *output_image* to *output_path*, then print where it went.

    Args:
        output_image: object exposing ``save(path)``, typically the RGBA
            PIL image built by ``apply_mask_to_image``.
        output_path: destination passed unchanged to ``output_image.save``.
    """
    write_to = output_image.save
    write_to(output_path)
    print(f"Segmented person saved to: {output_path}")
# Lazily-loaded model shared by calls that do not supply their own.
_default_model = None


def _get_default_model():
    """Return the shared Mask R-CNN, loading it only on the first call."""
    global _default_model
    if _default_model is None:
        _default_model = load_model()
    return _default_model


def segment_person(image_pil, output_path="", model=None):
    """Cut a detected person out of *image_pil* as a transparent image.

    Args:
        image_pil: PIL image to segment.
        output_path: when non-empty, the result is also saved to this path.
        model: optional pre-loaded model (e.g. from ``load_model()``). When
            omitted, a model is loaded on the first call and reused by
            later calls, instead of being rebuilt on every call.

    Returns:
        An RGBA PIL image whose alpha channel is the person mask, or
        ``None`` when no person passes the confidence threshold.
    """
    if model is None:
        model = _get_default_model()

    mask = extract_person_mask(model, image_pil)
    if mask is None:
        print("No person detected with high enough confidence.")
        return None

    segmented_image = apply_mask_to_image(image_pil, mask)
    if output_path:
        save_segmented_person(segmented_image, output_path)
    return segmented_image
if __name__ == "__main__":
    # Example usage: segment a sample photo and write the cut-out to a PNG.
    input_image_path = "./person/person1.jpg"
    output_image_path = "segmented_person.png"
    source_image = Image.open(input_image_path)
    segment_person(source_image.convert("RGB"), output_image_path)