Spaces:
Runtime error
Runtime error
| import torch.nn as nn | |
| from torchvision.models import mobilenet_v2, MobileNet_V2_Weights | |
| class MobileNetV2Backbone(nn.Module): | |
| def __init__(self): | |
| super(MobileNetV2Backbone, self).__init__() | |
| backbone = mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT) | |
| features = backbone.features | |
| # Слои, откуда берём фичи (по оригинальной архитектуре MODNet) | |
| self.enc2x = features[:4] # Примерно до блока 3 | |
| self.enc4x = features[4:7] # Примерно до блока 6 | |
| self.enc8x = features[7:14] # Примерно до блока 13 | |
| self.enc32x = features[14:] # Оставшиеся блоки | |
| self.enc_channels = [24, 32, 96, 1280] # Обновлено: финальный слой имеет 1280 каналов | |
| def forward(self, x): | |
| out2x = self.enc2x(x) | |
| out4x = self.enc4x(out2x) | |
| out8x = self.enc8x(out4x) | |
| out32x = self.enc32x(out8x) | |
| return [out2x, out4x, out8x, out32x] | |
| SUPPORTED_BACKBONES = { | |
| 'mobilenetv2': lambda _=None: MobileNetV2Backbone() | |
| } | |
| def load_modnet_checkpoint(modnet, path, device): | |
| state_dict = torch.load(path, map_location=device) | |
| new_state_dict = {} | |
| for k, v in state_dict.items(): | |
| if k.startswith('module.'): | |
| k = k[len('module.'):] # удаляем префикс, если обучалось с DataParallel | |
| new_state_dict[k] = v | |
| missing, unexpected = modnet.load_state_dict(new_state_dict, strict=False) | |
| print("MISSING:", missing) | |
| print("UNEXPECTED:", unexpected) | |
| return modnet |