# detect4corners / model_fasterrcnn.py
# letrunglinh's picture
# Upload 2 files
# 3eda55a
import torch
import torch.nn as nn
# import torch.nn.utils.prune as prune
import torchvision.models as models
import torchvision
# from torchsummary import summary
class MobileNetV2FeatureExtractor(nn.Module):
    """Feature extractor that wraps the backbone of a torchvision Faster R-CNN.

    NOTE: despite the class name, the network built here is the ``.backbone``
    of ``torchvision.models.detection.fasterrcnn_resnet50_fpn_v2`` and not a
    MobileNetV2. The name is kept so existing imports keep working.

    The model is created without pretrained weights and every backbone
    parameter is left trainable.
    """

    def __init__(self):
        super().__init__()
        # ``weights=None`` is the current spelling of the deprecated
        # ``pretrained=False`` keyword: no checkpoint is loaded.
        detector = torchvision.models.detection.fasterrcnn_resnet50_fpn_v2(
            weights=None
        )
        # Only the backbone is kept; the rest of the detector is discarded.
        self.model = detector.backbone
        for param in self.model.parameters():
            param.requires_grad = True

    def forward(self, x):
        """Run ``x`` through the wrapped backbone and return its output.

        The output is passed on unchanged; ``GlobalAvgPool2D`` in this file
        reads the entry stored under the key ``'0'`` from it.
        """
        return self.model(x)
class GlobalAvgPool2D(nn.Module):
    """Per-channel average of the feature map stored under key ``'0'``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Average ``x['0']`` over everything after its first two dimensions.

        Args:
            x: mapping that holds a tensor under the key ``'0'``.

        Returns:
            A tensor of shape ``(size(0), size(1))`` containing the mean of
            all remaining elements for each (dim-0, dim-1) pair.
        """
        feature_map = x['0']
        batch = feature_map.size(0)
        channels = feature_map.size(1)
        # Collapse all trailing dimensions into one, then average over it.
        flattened = feature_map.view(batch, channels, -1)
        return flattened.mean(dim=2)
class LDRNet_fasterrcnn(nn.Module):
    """Backbone + global pooling followed by three linear heads.

    The heads are applied to the same pooled feature vector:

    * ``corner``: 8 outputs.
    * ``border``: ``(points_size - 4) * 2`` outputs.
    * ``cls``: ``sum(classification_list) + len(classification_list)``
      outputs.

    Args:
        points_size: controls the width of the ``border`` head (see above).
        classification_list: sequence of integers that controls the width of
            the ``cls`` head (see above). Defaults to ``[1]``; a fresh list is
            created per instance when the argument is omitted.
    """

    def __init__(self, points_size=100, classification_list=None):
        super().__init__()
        # A ``None`` sentinel replaces the former mutable ``[1]`` default so
        # instances never share (and cannot accidentally mutate) one list.
        if classification_list is None:
            classification_list = [1]
        self.points_size = points_size
        self.classification_list = classification_list
        self.backbone = MobileNetV2FeatureExtractor()
        # ``sum`` of an empty sequence is 0, so no emptiness check is needed.
        class_size = sum(self.classification_list)
        self.global_pool = GlobalAvgPool2D()
        # self.dropout = nn.Dropout(p=0.3)
        # Width of the pooled feature vector that feeds every head.
        feature_dim = 256
        self.corner = nn.Linear(feature_dim, 8)
        self.border = nn.Linear(feature_dim, (points_size - 4) * 2)
        self.cls = nn.Linear(
            feature_dim, class_size + len(self.classification_list)
        )

    def forward(self, x):
        """Return ``(corner_output, border_output, cls_output)`` for ``x``.

        ``x`` is passed through the backbone and the global pooling layer;
        the pooled features are then fed to each of the three linear heads.
        """
        x = self.backbone(x)
        x = self.global_pool(x)
        # x = self.dropout(x)
        corner_output = self.corner(x)
        border_output = self.border(x)
        cls_output = self.cls(x)
        return corner_output, border_output, cls_output
if __name__ == "__main__":
    # Smoke test: build the model, run one forward pass on a dummy input,
    # apply L1 unstructured pruning and print parameter counts.
    import torch
    # The file-level import of ``prune`` is commented out, so without this
    # local import the pruning calls below raise NameError.
    import torch.nn.utils.prune as prune
    # from torchsummary import summary

    xx = torch.zeros((1, 3, 224, 224))
    model = LDRNet_fasterrcnn()
    print(model)
    y = model(xx)

    # Prune 20% of each Conv2d weight and 40% of each Linear weight.
    for module in model.modules():
        if isinstance(module, torch.nn.Conv2d):
            prune.l1_unstructured(module, name='weight', amount=0.2)
        elif isinstance(module, torch.nn.Linear):
            prune.l1_unstructured(module, name='weight', amount=0.4)
    # print(y[0].detach().numpy()[0])
    # summary(model,input_size=(3, 224, 224))

    # NOTE(review): per the torch pruning docs, pruning masks weights rather
    # than removing them, so these counts are presumably not reduced by the
    # pruning above -- confirm if a sparsity figure is what is wanted.
    total_params = sum(p.numel() for p in model.parameters())
    total_trainable_params = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    print(f"[INFO]: {total_params:,} total parameters.")
    print(f"[INFO]: {total_trainable_params:,} trainable parameters.")