Spaces:
Runtime error
Runtime error
Update model/modnet.py
Browse files- model/modnet.py +26 -14
model/modnet.py
CHANGED
|
@@ -6,14 +6,32 @@ from torchvision.models.mobilenetv2 import mobilenet_v2
|
|
| 6 |
import torch.nn as nn
|
| 7 |
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
class MODNet(nn.Module):
|
| 10 |
def __init__(self, in_channels=3, hr_channels=32, backbone_arch='mobilenetv2', backbone_pretrained=True):
|
| 11 |
super(MODNet, self).__init__()
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
def forward(self, x, inference=False):
|
| 15 |
-
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
def preprocess_image(image: Image.Image, device: torch.device) -> torch.Tensor:
|
| 19 |
img = np.array(image.convert("RGB"))
|
|
@@ -23,22 +41,16 @@ def preprocess_image(image: Image.Image, device: torch.device) -> torch.Tensor:
|
|
| 23 |
return img_input
|
| 24 |
|
| 25 |
|
| 26 |
-
def clean_state_dict(state_dict):
|
| 27 |
-
"""Remove 'module.' prefix if present in keys."""
|
| 28 |
-
new_state_dict = {}
|
| 29 |
-
for k, v in state_dict.items():
|
| 30 |
-
if k.startswith('module.'):
|
| 31 |
-
new_state_dict[k[7:]] = v
|
| 32 |
-
else:
|
| 33 |
-
new_state_dict[k] = v
|
| 34 |
-
return new_state_dict
|
| 35 |
-
|
| 36 |
-
|
| 37 |
def remove_background_modnet(image: Image.Image) -> Image.Image:
|
| 38 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 39 |
|
| 40 |
modnet = MODNet()
|
| 41 |
modnet.to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
modnet.eval()
|
| 43 |
|
| 44 |
img_input = preprocess_image(image, device)
|
|
|
|
| 6 |
import torch.nn as nn
|
| 7 |
|
| 8 |
|
| 9 |
+
def clean_state_dict(state_dict):
|
| 10 |
+
"""Remove 'module.' prefix if present in keys."""
|
| 11 |
+
new_state_dict = {}
|
| 12 |
+
for k, v in state_dict.items():
|
| 13 |
+
if k.startswith('module.'):
|
| 14 |
+
new_state_dict[k[7:]] = v
|
| 15 |
+
else:
|
| 16 |
+
new_state_dict[k] = v
|
| 17 |
+
return new_state_dict
|
| 18 |
+
|
| 19 |
+
|
| 20 |
class MODNet(nn.Module):
|
| 21 |
def __init__(self, in_channels=3, hr_channels=32, backbone_arch='mobilenetv2', backbone_pretrained=True):
|
| 22 |
super(MODNet, self).__init__()
|
| 23 |
+
|
| 24 |
+
self.backbone = mobilenet_v2(pretrained=backbone_pretrained).features
|
| 25 |
+
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
| 26 |
+
self.fc = nn.Linear(1280, 320) # Example dimensions, adjust as needed
|
| 27 |
|
| 28 |
def forward(self, x, inference=False):
|
| 29 |
+
features = self.backbone(x)
|
| 30 |
+
pooled = self.avgpool(features)
|
| 31 |
+
flattened = torch.flatten(pooled, 1)
|
| 32 |
+
semantic = self.fc(flattened)
|
| 33 |
+
return semantic, semantic, semantic # Dummy triple output for compatibility
|
| 34 |
+
|
| 35 |
|
| 36 |
def preprocess_image(image: Image.Image, device: torch.device) -> torch.Tensor:
|
| 37 |
img = np.array(image.convert("RGB"))
|
|
|
|
| 41 |
return img_input
|
| 42 |
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
def remove_background_modnet(image: Image.Image) -> Image.Image:
|
| 45 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 46 |
|
| 47 |
modnet = MODNet()
|
| 48 |
modnet.to(device)
|
| 49 |
+
|
| 50 |
+
# Load weights
|
| 51 |
+
state_dict = torch.load('pretrained/modnet_webcam_portrait_matting.ckpt', map_location=device)
|
| 52 |
+
modnet.load_state_dict(clean_state_dict(state_dict), strict=False)
|
| 53 |
+
|
| 54 |
modnet.eval()
|
| 55 |
|
| 56 |
img_input = preprocess_image(image, device)
|