remove-bg-modnet / model /backbones /mobilenetv2.py
bluspater's picture
Update model/backbones/mobilenetv2.py
3b09131 verified
import torch.nn as nn
from torchvision.models import mobilenet_v2, MobileNet_V2_Weights
class MobileNetV2Backbone(nn.Module):
def __init__(self):
super(MobileNetV2Backbone, self).__init__()
backbone = mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT)
features = backbone.features
# Слои, откуда берём фичи (по оригинальной архитектуре MODNet)
self.enc2x = features[:4] # Примерно до блока 3
self.enc4x = features[4:7] # Примерно до блока 6
self.enc8x = features[7:14] # Примерно до блока 13
self.enc32x = features[14:] # Оставшиеся блоки
self.enc_channels = [24, 32, 96, 1280] # Обновлено: финальный слой имеет 1280 каналов
def forward(self, x):
out2x = self.enc2x(x)
out4x = self.enc4x(out2x)
out8x = self.enc8x(out4x)
out32x = self.enc32x(out8x)
return [out2x, out4x, out8x, out32x]
SUPPORTED_BACKBONES = {
'mobilenetv2': lambda _=None: MobileNetV2Backbone()
}
def load_modnet_checkpoint(modnet, path, device):
state_dict = torch.load(path, map_location=device)
new_state_dict = {}
for k, v in state_dict.items():
if k.startswith('module.'):
k = k[len('module.'):] # удаляем префикс, если обучалось с DataParallel
new_state_dict[k] = v
missing, unexpected = modnet.load_state_dict(new_state_dict, strict=False)
print("MISSING:", missing)
print("UNEXPECTED:", unexpected)
return modnet