bluspater commited on
Commit
afbb618
·
verified ·
1 Parent(s): 6893f8e

Create model/backbones/mobilenetv2.py

Browse files
Files changed (1) hide show
  1. model/backbones/mobilenetv2.py +18 -0
model/backbones/mobilenetv2.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from torchvision.models import mobilenet_v2
3
+
4
+
5
+ class MobileNetV2Backbone(nn.Module):
6
+ def __init__(self):
7
+ super().__init__()
8
+ model = mobilenet_v2(pretrained=True)
9
+ self.model = model.features
10
+ self.enc_channels = [24, 32, 96, 320] # фиксированные выходы энкодера
11
+
12
+ def forward(self, x):
13
+ features = []
14
+ for idx, layer in enumerate(self.model):
15
+ x = layer(x)
16
+ if idx in {2, 4, 7, 13}:
17
+ features.append(x)
18
+ return features