| | import torch
|
| | from torchvision import models, transforms
|
| | from PIL import Image
|
| | import numpy as np
|
| |
|
| |
|
def segment_person(image_path):
    """Cut the person out of an image using DeepLabV3 semantic segmentation.

    The image is run through a pretrained ``deeplabv3_resnet101`` model and
    every pixel whose most likely class is "person" is kept opaque; all other
    pixels become fully transparent.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        A ``PIL.Image.Image`` in RGBA mode with the same size as the input.
        The RGB channels are the original pixels; the alpha channel is 255
        on person pixels and 0 elsewhere.
    """
    # Index of the "person" category in the model's output channels.
    # NOTE(review): 15 follows the Pascal VOC category ordering used by the
    # torchvision pretrained segmentation weights -- confirm against the
    # weights' `meta["categories"]` for the installed torchvision version.
    person_class = 15

    # NOTE(review): newer torchvision releases prefer `weights=` over
    # `pretrained=`; the original argument is kept for compatibility.
    model = models.segmentation.deeplabv3_resnet101(pretrained=True).eval()

    # Close the underlying file as soon as the pixels are decoded.
    with Image.open(image_path) as source:
        input_image = source.convert("RGB")

    preprocess = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
    )
    # Add the batch dimension expected by the model.
    input_tensor = preprocess(input_image).unsqueeze(0)

    with torch.no_grad():
        output = model(input_tensor)["out"][0]

    # argmax yields a class index per pixel, not a binary mask. Compare with
    # the person index first so the alpha channel is exactly 0 or 255;
    # multiplying raw uint8 class indices by 255 would wrap around and mark
    # every non-background class as (partially) opaque.
    person_mask = output.argmax(0) == person_class
    alpha = person_mask.byte().numpy() * 255

    rgba = np.dstack([np.array(input_image), alpha])
    return Image.fromarray(rgba)
|
| |
|