bluspater commited on
Commit
b6584d4
·
verified ·
1 Parent(s): e69a94a

Update model/modnet.py

Browse files
Files changed (1) hide show
  1. model/modnet.py +65 -58
model/modnet.py CHANGED
@@ -1,48 +1,73 @@
1
  import torch
2
- import cv2
 
3
  import numpy as np
 
4
  from PIL import Image
5
- from torchvision.models.mobilenetv2 import mobilenet_v2
6
- import torch.nn as nn
7
-
8
-
9
- def clean_state_dict(state_dict):
10
- """Remove 'module.' prefix if present in keys."""
11
- new_state_dict = {}
12
- for k, v in state_dict.items():
13
- if k.startswith('module.'):
14
- new_state_dict[k[7:]] = v
15
- else:
16
- new_state_dict[k] = v
17
- return new_state_dict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
 
20
  class MODNet(nn.Module):
21
- def __init__(self, in_channels=3, hr_channels=32, backbone_pretrained=True):
22
  super(MODNet, self).__init__()
23
-
24
- mobilenet = mobilenet_v2(pretrained=backbone_pretrained)
25
- self.backbone = mobilenet.features # nn.Sequential already
26
-
27
- # Simulate enc_channels expected by MODNet-style branches
28
- self.enc_channels = [24, 32, 96, 320]
29
-
30
- # Dummy branches to satisfy loading
31
- self.lr_branch = nn.Identity()
32
- self.hr_branch = nn.Identity()
33
- self.f_branch = nn.Identity()
34
-
35
- def forward(self, x, inference=False):
36
  features = self.backbone(x)
37
- return features, features, features # Dummy outputs
 
38
 
39
 
40
  def preprocess_image(image: Image.Image, device: torch.device) -> torch.Tensor:
41
- img = np.array(image.convert("RGB"))
42
- img_resized = cv2.resize(img, (512, 512))
43
- img_input = img_resized / 255.0
44
- img_input = torch.tensor(img_input, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0).to(device)
45
- return img_input
 
 
 
46
 
47
 
48
  def remove_background_modnet(image: Image.Image) -> Image.Image:
@@ -51,37 +76,19 @@ def remove_background_modnet(image: Image.Image) -> Image.Image:
51
  modnet = MODNet()
52
  modnet.to(device)
53
 
54
- # Load weights
55
- state_dict = torch.load('pretrained/modnet_webcam_portrait_matting.ckpt', map_location=device)
56
- modnet.load_state_dict(clean_state_dict(state_dict), strict=False)
57
-
58
  modnet.eval()
59
 
60
- img_input = preprocess_image(image, device)
61
 
62
  with torch.no_grad():
63
- output = modnet(img_input, True)
64
-
65
- if output is None:
66
- raise RuntimeError("MODNet returned None. Ensure model is correctly initialized and forward method is implemented.")
67
- if not isinstance(output, (tuple, list)):
68
- raise TypeError(f"MODNet output must be a list or tuple, got {type(output)}")
69
- if len(output) < 3:
70
- raise ValueError(f"Expected at least 3 outputs from MODNet, got {len(output)}")
71
-
72
- pred_semantic, pred_detail, pred_matte = output
73
-
74
- if pred_matte is None:
75
- raise RuntimeError("pred_matte is None — MODNet forward method may not be returning expected outputs.")
76
 
77
  matte = pred_matte[0][0].cpu().numpy()
78
  matte = cv2.resize(matte, image.size)
79
  matte = np.uint8(matte * 255)
80
 
81
- rgba_image = image.convert("RGBA")
82
- image_np = np.array(rgba_image)
83
- if image_np.shape[2] < 4:
84
- alpha_channel = 255 * np.ones((*image_np.shape[:2], 1), dtype=np.uint8)
85
- image_np = np.concatenate([image_np, alpha_channel], axis=2)
86
  image_np[:, :, 3] = matte
87
- return Image.fromarray(image_np)
 
1
  import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
  import numpy as np
5
+ import cv2
6
  from PIL import Image
7
+ from torchvision import transforms
8
+
9
+
10
+ # Backbone: U2NET-like architecture (simplified for inference only)
11
+ class BasicConvBlock(nn.Module):
12
+ def __init__(self, in_channels, out_channels):
13
+ super(BasicConvBlock, self).__init__()
14
+ self.block = nn.Sequential(
15
+ nn.Conv2d(in_channels, out_channels, 3, 1, 1),
16
+ nn.BatchNorm2d(out_channels),
17
+ nn.ReLU(inplace=True),
18
+ nn.Conv2d(out_channels, out_channels, 3, 1, 1),
19
+ nn.BatchNorm2d(out_channels),
20
+ nn.ReLU(inplace=True),
21
+ )
22
+
23
+ def forward(self, x):
24
+ return self.block(x)
25
+
26
+
27
+ class SimpleMODNetBackbone(nn.Module):
28
+ def __init__(self):
29
+ super(SimpleMODNetBackbone, self).__init__()
30
+ self.stage1 = BasicConvBlock(3, 64)
31
+ self.pool1 = nn.MaxPool2d(2, 2)
32
+ self.stage2 = BasicConvBlock(64, 128)
33
+ self.pool2 = nn.MaxPool2d(2, 2)
34
+ self.stage3 = BasicConvBlock(128, 256)
35
+
36
+ def forward(self, x):
37
+ x = self.stage1(x)
38
+ x = self.pool1(x)
39
+ x = self.stage2(x)
40
+ x = self.pool2(x)
41
+ x = self.stage3(x)
42
+ return x
43
 
44
 
45
  class MODNet(nn.Module):
46
+ def __init__(self):
47
  super(MODNet, self).__init__()
48
+ self.backbone = SimpleMODNetBackbone()
49
+ self.seg_head = nn.Sequential(
50
+ nn.Conv2d(256, 64, kernel_size=3, padding=1),
51
+ nn.ReLU(),
52
+ nn.Conv2d(64, 1, kernel_size=1),
53
+ nn.Sigmoid()
54
+ )
55
+
56
+ def forward(self, x):
 
 
 
 
57
  features = self.backbone(x)
58
+ pred_matte = self.seg_head(features)
59
+ return pred_matte
60
 
61
 
62
  def preprocess_image(image: Image.Image, device: torch.device) -> torch.Tensor:
63
+ transform = transforms.Compose([
64
+ transforms.Resize((512, 512)),
65
+ transforms.ToTensor(),
66
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
67
+ std=[0.229, 0.224, 0.225])
68
+ ])
69
+ img_tensor = transform(image.convert("RGB")).unsqueeze(0).to(device)
70
+ return img_tensor
71
 
72
 
73
  def remove_background_modnet(image: Image.Image) -> Image.Image:
 
76
  modnet = MODNet()
77
  modnet.to(device)
78
 
79
+ # Skip loading weights (simple version)
 
 
 
80
  modnet.eval()
81
 
82
+ img_tensor = preprocess_image(image, device)
83
 
84
  with torch.no_grad():
85
+ pred_matte = modnet(img_tensor)
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  matte = pred_matte[0][0].cpu().numpy()
88
  matte = cv2.resize(matte, image.size)
89
  matte = np.uint8(matte * 255)
90
 
91
+ image = image.convert("RGBA")
92
+ image_np = np.array(image)
 
 
 
93
  image_np[:, :, 3] = matte
94
+ return Image.fromarray(image_np)