bluspater commited on
Commit
4894062
·
verified ·
1 Parent(s): b5b8418

Create modnet_model.py

Browse files
Files changed (1) hide show
  1. model/modnet_model.py +27 -0
model/modnet_model.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import torch
3
+ from torchvision import transforms
4
+ from model.modnet import MODNet # или как у тебя организована модель
5
+ import cv2
6
+ import numpy as np
7
+
8
+ # Предобученная модель
9
+ modnet = MODNet()
10
+ modnet.load_state_dict(torch.load('pretrained/modnet_webcam_portrait_matting.ckpt', map_location='cpu'))
11
+ modnet.eval()
12
+
13
+ def remove_background_modnet(image: Image.Image) -> Image.Image:
14
+ image = image.convert('RGB').resize((512, 512))
15
+ img_np = np.array(image).astype(np.uint8)
16
+ img_tensor = transforms.ToTensor()(img_np).unsqueeze(0)
17
+
18
+ with torch.no_grad():
19
+ _, _, matte = modnet(img_tensor)
20
+
21
+ matte_np = matte.squeeze().cpu().numpy()
22
+ matte_np = (matte_np * 255).astype(np.uint8)
23
+
24
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGRA)
25
+ img_np[:, :, 3] = matte_np
26
+
27
+ return Image.fromarray(img_np)