Spaces:
Build error
Build error
| import torch | |
| import torch.nn as nn | |
| # import torch.nn.utils.prune as prune | |
| import torchvision.models as models | |
| import torchvision | |
| # from torchsummary import summary | |
class MobileNetV2FeatureExtractor(nn.Module):
    """Feature extractor built from the backbone of a torchvision Faster R-CNN.

    NOTE: despite the class name, no MobileNetV2 is involved. The model is
    ``fasterrcnn_resnet50_fpn_v2`` and only its ``backbone`` attribute
    (ResNet-50 + FPN) is kept. The name is preserved so existing callers and
    checkpoints keep working.

    All weights are randomly initialised: neither the detection checkpoint nor
    the ImageNet backbone checkpoint is loaded. Every kept parameter is
    explicitly marked trainable.
    """

    def __init__(self):
        super().__init__()
        # ``weights`` / ``weights_backbone`` replace the deprecated ``pretrained``
        # flag. ``None`` for both reproduces the old ``pretrained=False``.
        detector = torchvision.models.detection.fasterrcnn_resnet50_fpn_v2(
            weights=None, weights_backbone=None
        )
        # Only the backbone is used. The detector's RPN / ROI heads are dropped.
        self.model = detector.backbone
        for param in self.model.parameters():
            param.requires_grad = True

    def forward(self, x):
        """Run the backbone on ``x`` and return its output unchanged.

        ``x`` is presumably an image batch of shape (N, 3, H, W); TODO confirm.
        The output is whatever the torchvision FPN backbone returns, presumably
        a mapping of feature maps keyed by level ('0', '1', ...), which
        ``GlobalAvgPool2D`` indexes with ``'0'``.
        """
        return self.model(x)
class GlobalAvgPool2D(nn.Module):
    """Global average pooling over the trailing (spatial) dims of one feature map.

    Args:
        key: Entry to read when the input is a mapping of feature maps, e.g.
            the dict returned by an FPN backbone. Defaults to ``'0'``, the key
            that was previously hard-coded.

    Input is either such a mapping or a tensor with at least two dims,
    ``(N, C, ...)``. Output has shape ``(N, C)``: the mean over every position
    after the first two dims.
    """

    def __init__(self, key='0'):
        super().__init__()
        self.key = key

    def forward(self, x):
        tensor = x if isinstance(x, torch.Tensor) else x[self.key]
        # reshape() matches view() exactly on contiguous inputs but also works
        # on non-contiguous ones (where view() raises) by copying if needed.
        flat = tensor.reshape(tensor.size(0), tensor.size(1), -1)
        return torch.mean(flat, dim=2)
class LDRNet_fasterrcnn(nn.Module):
    """Point-regression / classification heads on top of a Faster R-CNN backbone.

    The backbone's ``'0'`` feature map is globally average-pooled into a
    feature vector. Three linear heads are then applied to that vector.

    Args:
        points_size: Total number of points to predict. Four go to the
            ``corner`` head (8 outputs) and the remaining ``points_size - 4``
            go to the ``border`` head (2 outputs each, presumably (x, y)).
            Must be at least 4.
        classification_list: Integers configuring the ``cls`` head, which
            emits ``sum(classification_list) + len(classification_list)``
            logits. Presumably one entry per classification task, plus one
            extra logit per task; confirm against the training code.
            Defaults to ``[1]``.
    """

    # Width of the pooled feature vector. Assumed to equal the FPN
    # out_channels of the torchvision backbone (256); update if it changes.
    _FEATURE_DIM = 256

    def __init__(self, points_size=100, classification_list=None):
        super().__init__()
        if classification_list is None:
            # Fresh list per instance. A literal ``[1]`` default would be one
            # mutable object shared by every model built with the default.
            classification_list = [1]
        self.points_size = points_size
        self.classification_list = classification_list
        self.backbone = MobileNetV2FeatureExtractor()
        # sum([]) == 0, so an empty list needs no special case.
        class_size = sum(self.classification_list)
        self.global_pool = GlobalAvgPool2D()
        self.corner = nn.Linear(self._FEATURE_DIM, 8)
        self.border = nn.Linear(self._FEATURE_DIM, (points_size - 4) * 2)
        self.cls = nn.Linear(
            self._FEATURE_DIM, class_size + len(self.classification_list)
        )

    def forward(self, x):
        """Return ``(corner_output, border_output, cls_output)`` for batch ``x``.

        Output shapes are ``(N, 8)``, ``(N, (points_size - 4) * 2)`` and
        ``(N, sum(classification_list) + len(classification_list))``.
        """
        features = self.global_pool(self.backbone(x))
        corner_output = self.corner(features)
        border_output = self.border(features)
        cls_output = self.cls(features)
        return corner_output, border_output, cls_output
if __name__ == "__main__":
    # Imported locally so the pruning demo below does not depend on the
    # module-level imports (a missing ``prune`` raised NameError here).
    import torch.nn.utils.prune as prune

    xx = torch.zeros((1, 3, 224, 224))
    model = LDRNet_fasterrcnn()
    print(model)

    # Smoke-test forward pass; no autograd graph is needed.
    with torch.no_grad():
        y = model(xx)

    # L1-magnitude unstructured pruning: 20% of every conv weight and 40% of
    # every linear weight.
    for module in model.modules():
        if isinstance(module, torch.nn.Conv2d):
            prune.l1_unstructured(module, name='weight', amount=0.2)
        elif isinstance(module, torch.nn.Linear):
            prune.l1_unstructured(module, name='weight', amount=0.4)

    # Pruning masks weights (``weight_orig`` parameter * ``weight_mask`` buffer)
    # instead of removing them, so these counts still include zeroed weights.
    total_params = sum(p.numel() for p in model.parameters())
    total_trainable_params = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    print(f"[INFO]: {total_params:,} total parameters.")
    print(f"[INFO]: {total_trainable_params:,} trainable parameters.")