bluspater commited on
Commit
b2f70fb
·
verified ·
1 Parent(s): 8b71584

Update model/backbones.py

Browse files
Files changed (1) hide show
  1. model/backbones.py +12 -7
model/backbones.py CHANGED
@@ -1,14 +1,19 @@
1
  import torch.nn as nn
2
- import torchvision.models as models
3
 
 
 
 
 
 
 
4
 
5
- def _build_mobilenetv2_backbone():
6
- backbone = models.mobilenet_v2(pretrained=True).features
7
- backbone.enc_channels = [24, 32, 96, 160, 320]
8
- return backbone
9
 
10
 
 
11
  SUPPORTED_BACKBONES = {
12
- 'mobilenetv2': lambda in_channels: _build_mobilenetv2_backbone()
13
- }
14
 
 
1
  import torch.nn as nn
2
+ from torchvision.models import mobilenet_v2, MobileNet_V2_Weights
3
 
4
+ class MobileNetV2Backbone(nn.Module):
5
+ def __init__(self):
6
+ super(MobileNetV2Backbone, self).__init__()
7
+ # Используем актуальный способ загрузки предобученной модели
8
+ self.model = mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
9
+ self.enc_channels = [24, 32, 96, 320] # Каналы с разных уровней feature extractor
10
 
11
+ def forward(self, x):
12
+ return self.model(x)
 
 
13
 
14
 
15
+ # Добавим SUPPORTED_BACKBONES прямо сюда (или перенеси в backbones/__init__.py, если нужно)
16
  SUPPORTED_BACKBONES = {
17
+ 'mobilenetv2': lambda _=None: MobileNetV2Backbone()
18
+ } # _=None позволяет игнорировать входной параметр
19