Spaces:
Runtime error
Runtime error
Update model/modnet.py
Browse files- model/modnet.py +3 -11
model/modnet.py
CHANGED
|
@@ -9,13 +9,11 @@ import torch.nn as nn
|
|
| 9 |
class MODNet(nn.Module):
|
| 10 |
def __init__(self):
|
| 11 |
super(MODNet, self).__init__()
|
| 12 |
-
self.backbone =
|
| 13 |
|
| 14 |
def forward(self, x, inference=False):
|
| 15 |
-
#
|
| 16 |
-
x
|
| 17 |
-
x_avg = torch.mean(x, dim=1, keepdim=True)
|
| 18 |
-
return x_avg, x_avg, x_avg # dummy semantic, detail, matte
|
| 19 |
|
| 20 |
def preprocess_image(image: Image.Image, device: torch.device) -> torch.Tensor:
|
| 21 |
img = np.array(image.convert("RGB"))
|
|
@@ -40,12 +38,6 @@ def remove_background_modnet(image: Image.Image) -> Image.Image:
|
|
| 40 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 41 |
|
| 42 |
modnet = MODNet()
|
| 43 |
-
# NOTE: Commenting out checkpoint loading to avoid mismatch with dummy structure
|
| 44 |
-
# ckpt_path = 'pretrained/modnet_webcam_portrait_matting.ckpt'
|
| 45 |
-
# state_dict = torch.load(ckpt_path, map_location=device)
|
| 46 |
-
# state_dict = clean_state_dict(state_dict)
|
| 47 |
-
# modnet.load_state_dict(state_dict)
|
| 48 |
-
|
| 49 |
modnet.to(device)
|
| 50 |
modnet.eval()
|
| 51 |
|
|
|
|
| 9 |
class MODNet(nn.Module):
|
| 10 |
def __init__(self):
|
| 11 |
super(MODNet, self).__init__()
|
| 12 |
+
self.backbone = nn.Identity() # Replaced with identity for fast test
|
| 13 |
|
| 14 |
def forward(self, x, inference=False):
|
| 15 |
+
# Fast dummy forward to avoid GPU/memory bottlenecks
|
| 16 |
+
return x, x, x # dummy semantic, detail, matte
|
|
|
|
|
|
|
| 17 |
|
| 18 |
def preprocess_image(image: Image.Image, device: torch.device) -> torch.Tensor:
|
| 19 |
img = np.array(image.convert("RGB"))
|
|
|
|
| 38 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 39 |
|
| 40 |
modnet = MODNet()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
modnet.to(device)
|
| 42 |
modnet.eval()
|
| 43 |
|