bluspater commited on
Commit
dd7e0f9
·
verified ·
1 Parent(s): b6584d4

Update modnet_utils.py

Browse files
Files changed (1) hide show
  1. modnet_utils.py +3 -61
modnet_utils.py CHANGED
@@ -1,63 +1,5 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- import numpy as np
5
- import cv2
6
  from PIL import Image
7
- from torchvision import transforms
8
 
9
-
10
- class MODNet(nn.Module):
11
- def __init__(self, backbone):
12
- super(MODNet, self).__init__()
13
- self.backbone = backbone
14
- self.seg_head = nn.Sequential(
15
- nn.Conv2d(1280, 64, kernel_size=3, padding=1), # changed from 320 to 1280
16
- nn.ReLU(),
17
- nn.Conv2d(64, 1, kernel_size=1),
18
- nn.Sigmoid()
19
- )
20
-
21
- def forward(self, x):
22
- features = self.backbone(x)
23
- pred_matte = self.seg_head(features)
24
- return pred_matte
25
-
26
-
27
- def preprocess_image(image: Image.Image, device: torch.device) -> torch.Tensor:
28
- transform = transforms.Compose([
29
- transforms.Resize((512, 512)),
30
- transforms.ToTensor(),
31
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
32
- std=[0.229, 0.224, 0.225])
33
- ])
34
- img_tensor = transform(image.convert("RGB")).unsqueeze(0).to(device)
35
- return img_tensor
36
-
37
-
38
- def remove_background_modnet(image: Image.Image) -> Image.Image:
39
- from torchvision.models.mobilenet import mobilenet_v2
40
-
41
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
42
-
43
- backbone = mobilenet_v2(pretrained=True).features
44
- modnet = MODNet(backbone)
45
- modnet.to(device)
46
-
47
- state_dict = torch.load('pretrained/modnet_webcam_portrait_matting.ckpt', map_location=device)
48
- modnet.load_state_dict(state_dict, strict=False)
49
- modnet.eval()
50
-
51
- img_tensor = preprocess_image(image, device)
52
-
53
- with torch.no_grad():
54
- pred_matte = modnet(img_tensor)
55
-
56
- matte = pred_matte[0][0].cpu().numpy()
57
- matte = cv2.resize(matte, image.size)
58
- matte = np.uint8(matte * 255)
59
-
60
- image = image.convert("RGBA")
61
- image_np = np.array(image)
62
- image_np[:, :, 3] = matte
63
- return Image.fromarray(image_np)
 
 
 
 
 
 
1
  from PIL import Image
 
2
 
3
+ def remove_background(image: Image.Image) -> Image.Image:
4
+ from modnet_model import remove_background_modnet
5
+ return remove_background_modnet(image)