bluspater commited on
Commit
3405f42
·
verified ·
1 Parent(s): ac254a5

Update model/modnet.py

Browse files
Files changed (1) hide show
  1. model/modnet.py +3 -11
model/modnet.py CHANGED
@@ -9,13 +9,11 @@ import torch.nn as nn
9
  class MODNet(nn.Module):
10
  def __init__(self):
11
  super(MODNet, self).__init__()
12
- self.backbone = mobilenet_v2(pretrained=True).features
13
 
14
  def forward(self, x, inference=False):
15
- # Dummy pass-through forward for testing purposes.
16
- x = self.backbone(x)
17
- x_avg = torch.mean(x, dim=1, keepdim=True)
18
- return x_avg, x_avg, x_avg # dummy semantic, detail, matte
19
 
20
  def preprocess_image(image: Image.Image, device: torch.device) -> torch.Tensor:
21
  img = np.array(image.convert("RGB"))
@@ -40,12 +38,6 @@ def remove_background_modnet(image: Image.Image) -> Image.Image:
40
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
41
 
42
  modnet = MODNet()
43
- # NOTE: Commenting out checkpoint loading to avoid mismatch with dummy structure
44
- # ckpt_path = 'pretrained/modnet_webcam_portrait_matting.ckpt'
45
- # state_dict = torch.load(ckpt_path, map_location=device)
46
- # state_dict = clean_state_dict(state_dict)
47
- # modnet.load_state_dict(state_dict)
48
-
49
  modnet.to(device)
50
  modnet.eval()
51
 
 
9
  class MODNet(nn.Module):
10
  def __init__(self):
11
  super(MODNet, self).__init__()
12
+ self.backbone = nn.Identity() # Replaced with identity for fast test
13
 
14
  def forward(self, x, inference=False):
15
+ # Fast dummy forward to avoid GPU/memory bottlenecks
16
+ return x, x, x # dummy semantic, detail, matte
 
 
17
 
18
  def preprocess_image(image: Image.Image, device: torch.device) -> torch.Tensor:
19
  img = np.array(image.convert("RGB"))
 
38
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
39
 
40
  modnet = MODNet()
 
 
 
 
 
 
41
  modnet.to(device)
42
  modnet.eval()
43