bluspater commited on
Commit
782b77c
·
verified ·
1 Parent(s): c6423fb

Update model/backbones/mobilenetv2.py

Browse files
Files changed (1) hide show
  1. model/backbones/mobilenetv2.py +18 -8
model/backbones/mobilenetv2.py CHANGED
@@ -4,18 +4,28 @@ from torchvision.models import mobilenet_v2, MobileNet_V2_Weights
4
  class MobileNetV2Backbone(nn.Module):
5
  def __init__(self):
6
  super(MobileNetV2Backbone, self).__init__()
7
- # Используем актуальный способ загрузки предобученной модели
8
- self.model = mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
9
- self.enc_channels = [24, 32, 96, 320] # Каналы с разных уровней feature extractor
 
 
 
 
 
 
 
10
 
11
  def forward(self, x):
12
- return self.model(x)
 
 
 
 
13
 
14
 
15
- # Добавим SUPPORTED_BACKBONES прямо сюда (или перенеси в backbones/__init__.py, если нужно)
16
  SUPPORTED_BACKBONES = {
17
  'mobilenetv2': lambda _=None: MobileNetV2Backbone()
18
- } # _=None позволяет игнорировать входной параметр
19
 
20
 
21
  def load_modnet_checkpoint(modnet, path, device):
@@ -23,10 +33,10 @@ def load_modnet_checkpoint(modnet, path, device):
23
  new_state_dict = {}
24
  for k, v in state_dict.items():
25
  if k.startswith('module.'):
26
- k = k[len('module.'):]
27
  new_state_dict[k] = v
28
 
29
  missing, unexpected = modnet.load_state_dict(new_state_dict, strict=False)
30
  print("MISSING:", missing)
31
  print("UNEXPECTED:", unexpected)
32
- return modnet
 
4
  class MobileNetV2Backbone(nn.Module):
5
  def __init__(self):
6
  super(MobileNetV2Backbone, self).__init__()
7
+ backbone = mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT)
8
+ features = backbone.features
9
+
10
+ # Слои, откуда берём фичи (по оригинальной архитектуре MODNet)
11
+ self.enc2x = features[:4] # Примерно до блока 3
12
+ self.enc4x = features[4:7] # Примерно до блока 6
13
+ self.enc8x = features[7:14] # Примерно до блока 13
14
+ self.enc32x = features[14:] # Оставшиеся блоки
15
+
16
+ self.enc_channels = [24, 32, 96, 320]
17
 
18
  def forward(self, x):
19
+ out2x = self.enc2x(x)
20
+ out4x = self.enc4x(out2x)
21
+ out8x = self.enc8x(out4x)
22
+ out32x = self.enc32x(out8x)
23
+ return [out2x, out4x, out8x, out32x]
24
 
25
 
 
26
  SUPPORTED_BACKBONES = {
27
  'mobilenetv2': lambda _=None: MobileNetV2Backbone()
28
+ }
29
 
30
 
31
  def load_modnet_checkpoint(modnet, path, device):
 
33
  new_state_dict = {}
34
  for k, v in state_dict.items():
35
  if k.startswith('module.'):
36
+ k = k[len('module.'):] # удаляем префикс, если обучалось с DataParallel
37
  new_state_dict[k] = v
38
 
39
  missing, unexpected = modnet.load_state_dict(new_state_dict, strict=False)
40
  print("MISSING:", missing)
41
  print("UNEXPECTED:", unexpected)
42
+ return modnet