Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from model.deep_lab_model.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d | |
| from model.deep_lab_model.aspp import build_aspp | |
| from model.deep_lab_model.decoder import build_decoder | |
| from model.deep_lab_model.backbone import build_backbone | |
class DeepLab(nn.Module):
    """DeepLab v3+ style segmentation network: backbone -> ASPP -> decoder -> upsample.

    Args:
        backbone: Backbone name forwarded to ``build_backbone``, ``build_aspp``
            and ``build_decoder`` (e.g. ``'resnet'``, ``'mobilenet'``, ``'drn'``).
        output_stride: Output stride passed to the backbone and ASPP builders.
            Silently forced to 8 when ``backbone == 'drn'``.
        num_classes: Number of classes, forwarded to ``build_decoder``.
        sync_bn: If truthy, build the network with ``SynchronizedBatchNorm2d``;
            otherwise with ``nn.BatchNorm2d``.
        freeze_bn: If truthy, ``get_1x_lr_params`` / ``get_10x_lr_params`` yield
            only ``nn.Conv2d`` parameters (BatchNorm parameters are excluded).
            This flag does NOT call :meth:`freeze_bn` for you.
    """

    def __init__(self, backbone='resnet', output_stride=16, num_classes=21,
                 sync_bn=True, freeze_bn=False):
        super(DeepLab, self).__init__()
        if backbone == 'drn':
            # NOTE(review): the caller's output_stride is overridden for drn.
            output_stride = 8

        BatchNorm = SynchronizedBatchNorm2d if sync_bn else nn.BatchNorm2d

        self.backbone = build_backbone(backbone, output_stride, BatchNorm)
        self.aspp = build_aspp(backbone, output_stride, BatchNorm)
        self.decoder = build_decoder(num_classes, backbone, BatchNorm)
        # Stored under a private name: assigning to ``self.freeze_bn`` would
        # shadow the ``freeze_bn()`` method below and make it uncallable.
        self._freeze_bn = freeze_bn

    def forward(self, input):
        """Return decoder logits bilinearly upsampled to the spatial size of ``input``.

        ``input`` must be 4-D (N, C, H, W), because its trailing two dimensions
        are used as the bilinear interpolation target size.
        """
        x, low_level_feat = self.backbone(input)
        x = self.aspp(x)
        x = self.decoder(x, low_level_feat)
        return F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True)

    def freeze_bn(self):
        """Switch every BatchNorm layer (synchronized or plain) to eval mode.

        A subsequent ``self.train()`` puts these layers back into training
        mode, so call this again after any ``train()`` call.
        """
        for m in self.modules():
            if isinstance(m, (SynchronizedBatchNorm2d, nn.BatchNorm2d)):
                m.eval()

    def _lr_params(self, modules):
        """Yield trainable parameters of Conv2d (and, unless BN is frozen, BatchNorm) layers."""
        if self._freeze_bn:
            layer_types = (nn.Conv2d,)
        else:
            layer_types = (nn.Conv2d, SynchronizedBatchNorm2d, nn.BatchNorm2d)
        for module in modules:
            for m in module.modules():
                if isinstance(m, layer_types):
                    for p in m.parameters():
                        if p.requires_grad:
                            yield p

    def get_1x_lr_params(self):
        """Yield the backbone parameters for the "1x" learning-rate group."""
        yield from self._lr_params([self.backbone])

    def get_10x_lr_params(self):
        """Yield the ASPP and decoder parameters for the "10x" learning-rate group."""
        yield from self._lr_params([self.aspp, self.decoder])
if __name__ == "__main__":
    # Smoke test: push one random 3-channel 513x513 tensor through a
    # MobileNet-backed model and print the output shape.
    model = DeepLab(backbone='mobilenet', output_stride=16)
    model.eval()
    dummy_input = torch.rand(1, 3, 513, 513)  # avoid shadowing builtin ``input``
    with torch.no_grad():  # inference only: skip building the autograd graph
        output = model(dummy_input)
    print(output.size())