bluspater commited on
Commit
93a6bfd
·
verified ·
1 Parent(s): 943c782

Update model/backbones/mobilenetv2.py

Browse files
Files changed (1) hide show
  1. model/backbones/mobilenetv2.py +6 -14
model/backbones/mobilenetv2.py CHANGED
@@ -1,20 +1,12 @@
1
  import torch.nn as nn
2
- from torchvision.models import mobilenet_v2
3
-
4
-
5
-
6
 
7
  class MobileNetV2Backbone(nn.Module):
8
  def __init__(self):
9
- super().__init__()
10
- model = mobilenet_v2(pretrained=True)
11
- self.model = backbone #model.features
12
- self.enc_channels = [24, 32, 96, 160, 320] # фиксированные выходы энкодера
13
 
14
  def forward(self, x):
15
- features = []
16
- for idx, layer in enumerate(self.model):
17
- x = layer(x)
18
- if idx in {2, 4, 7, 13}:
19
- features.append(x)
20
- return features
 
1
  import torch.nn as nn
2
+ from torchvision.models import mobilenet_v2, MobileNet_V2_Weights
 
 
 
3
 
4
  class MobileNetV2Backbone(nn.Module):
5
  def __init__(self):
6
+ super(MobileNetV2Backbone, self).__init__()
7
+ # Используем актуальный способ загрузки предобученной модели
8
+ self.model = mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
 
9
 
10
  def forward(self, x):
11
+ return self.model(x)
12
+