bluspater committed on
Commit
9d88509
·
verified ·
1 Parent(s): d40dd7d

Update model/modnet.py

Browse files
Files changed (1) hide show
  1. model/modnet.py +26 -14
model/modnet.py CHANGED
@@ -6,14 +6,32 @@ from torchvision.models.mobilenetv2 import mobilenet_v2
6
  import torch.nn as nn
7
 
8
 
 
 
 
 
 
 
 
 
 
 
 
9
  class MODNet(nn.Module):
10
  def __init__(self, in_channels=3, hr_channels=32, backbone_arch='mobilenetv2', backbone_pretrained=True):
11
  super(MODNet, self).__init__()
12
- self.backbone = nn.Identity() # Replaced with identity for fast test
 
 
 
13
 
14
  def forward(self, x, inference=False):
15
- # Fast dummy forward to avoid GPU/memory bottlenecks
16
- return x, x, x # dummy semantic, detail, matte
 
 
 
 
17
 
18
  def preprocess_image(image: Image.Image, device: torch.device) -> torch.Tensor:
19
  img = np.array(image.convert("RGB"))
@@ -23,22 +41,16 @@ def preprocess_image(image: Image.Image, device: torch.device) -> torch.Tensor:
23
  return img_input
24
 
25
 
26
- def clean_state_dict(state_dict):
27
- """Remove 'module.' prefix if present in keys."""
28
- new_state_dict = {}
29
- for k, v in state_dict.items():
30
- if k.startswith('module.'):
31
- new_state_dict[k[7:]] = v
32
- else:
33
- new_state_dict[k] = v
34
- return new_state_dict
35
-
36
-
37
  def remove_background_modnet(image: Image.Image) -> Image.Image:
38
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
39
 
40
  modnet = MODNet()
41
  modnet.to(device)
 
 
 
 
 
42
  modnet.eval()
43
 
44
  img_input = preprocess_image(image, device)
 
6
  import torch.nn as nn
7
 
8
 
9
+ def clean_state_dict(state_dict):
10
+ """Remove 'module.' prefix if present in keys."""
11
+ new_state_dict = {}
12
+ for k, v in state_dict.items():
13
+ if k.startswith('module.'):
14
+ new_state_dict[k[7:]] = v
15
+ else:
16
+ new_state_dict[k] = v
17
+ return new_state_dict
18
+
19
+
20
  class MODNet(nn.Module):
21
  def __init__(self, in_channels=3, hr_channels=32, backbone_arch='mobilenetv2', backbone_pretrained=True):
22
  super(MODNet, self).__init__()
23
+
24
+ self.backbone = mobilenet_v2(pretrained=backbone_pretrained).features
25
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
26
+ self.fc = nn.Linear(1280, 320) # Example dimensions, adjust as needed
27
 
28
  def forward(self, x, inference=False):
29
+ features = self.backbone(x)
30
+ pooled = self.avgpool(features)
31
+ flattened = torch.flatten(pooled, 1)
32
+ semantic = self.fc(flattened)
33
+ return semantic, semantic, semantic # Dummy triple output for compatibility
34
+
35
 
36
  def preprocess_image(image: Image.Image, device: torch.device) -> torch.Tensor:
37
  img = np.array(image.convert("RGB"))
 
41
  return img_input
42
 
43
 
 
 
 
 
 
 
 
 
 
 
 
44
  def remove_background_modnet(image: Image.Image) -> Image.Image:
45
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
46
 
47
  modnet = MODNet()
48
  modnet.to(device)
49
+
50
+ # Load weights
51
+ state_dict = torch.load('pretrained/modnet_webcam_portrait_matting.ckpt', map_location=device)
52
+ modnet.load_state_dict(clean_state_dict(state_dict), strict=False)
53
+
54
  modnet.eval()
55
 
56
  img_input = preprocess_image(image, device)