Spaces:
Runtime error
Runtime error
Update model/backbones/mobilenetv2.py
Browse files
model/backbones/mobilenetv2.py
CHANGED
|
@@ -4,18 +4,28 @@ from torchvision.models import mobilenet_v2, MobileNet_V2_Weights
|
|
| 4 |
class MobileNetV2Backbone(nn.Module):
|
| 5 |
def __init__(self):
|
| 6 |
super(MobileNetV2Backbone, self).__init__()
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
def forward(self, x):
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
-
# Добавим SUPPORTED_BACKBONES прямо сюда (или перенеси в backbones/__init__.py, если нужно)
|
| 16 |
SUPPORTED_BACKBONES = {
|
| 17 |
'mobilenetv2': lambda _=None: MobileNetV2Backbone()
|
| 18 |
-
}
|
| 19 |
|
| 20 |
|
| 21 |
def load_modnet_checkpoint(modnet, path, device):
|
|
@@ -23,10 +33,10 @@ def load_modnet_checkpoint(modnet, path, device):
|
|
| 23 |
new_state_dict = {}
|
| 24 |
for k, v in state_dict.items():
|
| 25 |
if k.startswith('module.'):
|
| 26 |
-
k = k[len('module.'):]
|
| 27 |
new_state_dict[k] = v
|
| 28 |
|
| 29 |
missing, unexpected = modnet.load_state_dict(new_state_dict, strict=False)
|
| 30 |
print("MISSING:", missing)
|
| 31 |
print("UNEXPECTED:", unexpected)
|
| 32 |
-
return modnet
|
|
|
|
| 4 |
class MobileNetV2Backbone(nn.Module):
|
| 5 |
def __init__(self):
|
| 6 |
super(MobileNetV2Backbone, self).__init__()
|
| 7 |
+
backbone = mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT)
|
| 8 |
+
features = backbone.features
|
| 9 |
+
|
| 10 |
+
# Слои, откуда берём фичи (по оригинальной архитектуре MODNet)
|
| 11 |
+
self.enc2x = features[:4] # Примерно до блока 3
|
| 12 |
+
self.enc4x = features[4:7] # Примерно до блока 6
|
| 13 |
+
self.enc8x = features[7:14] # Примерно до блока 13
|
| 14 |
+
self.enc32x = features[14:] # Оставшиеся блоки
|
| 15 |
+
|
| 16 |
+
self.enc_channels = [24, 32, 96, 320]
|
| 17 |
|
| 18 |
def forward(self, x):
|
| 19 |
+
out2x = self.enc2x(x)
|
| 20 |
+
out4x = self.enc4x(out2x)
|
| 21 |
+
out8x = self.enc8x(out4x)
|
| 22 |
+
out32x = self.enc32x(out8x)
|
| 23 |
+
return [out2x, out4x, out8x, out32x]
|
| 24 |
|
| 25 |
|
|
|
|
| 26 |
SUPPORTED_BACKBONES = {
|
| 27 |
'mobilenetv2': lambda _=None: MobileNetV2Backbone()
|
| 28 |
+
}
|
| 29 |
|
| 30 |
|
| 31 |
def load_modnet_checkpoint(modnet, path, device):
|
|
|
|
| 33 |
new_state_dict = {}
|
| 34 |
for k, v in state_dict.items():
|
| 35 |
if k.startswith('module.'):
|
| 36 |
+
k = k[len('module.'):] # удаляем префикс, если обучалось с DataParallel
|
| 37 |
new_state_dict[k] = v
|
| 38 |
|
| 39 |
missing, unexpected = modnet.load_state_dict(new_state_dict, strict=False)
|
| 40 |
print("MISSING:", missing)
|
| 41 |
print("UNEXPECTED:", unexpected)
|
| 42 |
+
return modnet
|