| # -*- coding: utf-8 -*- | |
| # @Time : 2024/8/6 下午3:44 | |
| # @Author : xiaoshun | |
| # @Email : 3038523973@qq.com | |
| # @File : unetmobv2.py | |
| # @Software: PyCharm | |
| import segmentation_models_pytorch as smp | |
| import torch | |
| from torch import nn as nn | |
| class UNetMobV2(nn.Module): | |
| def __init__(self,num_classes,in_channels=3): | |
| super().__init__() | |
| self.backbone = smp.Unet( | |
| encoder_name='mobilenet_v2', | |
| encoder_weights=None, | |
| in_channels=in_channels, | |
| classes=num_classes, | |
| ) | |
| def forward(self, x): | |
| x = self.backbone(x) | |
| return x | |
| if __name__ == '__main__': | |
| fake_image = torch.rand(1, 3, 224, 224) | |
| model = UNetMobV2(num_classes=2) | |
| output = model(fake_image) | |
| print(output.size()) |