Spaces:
Sleeping
Sleeping
| from torch import nn | |
| import timm | |
| class EfficientNet(nn.Module): | |
| def __init__(self): | |
| super().__init__() | |
| self.efficientnet = timm.create_model(model_name="efficientnet_b0", pretrained=True, num_classes=25) | |
| """# Set requires_grad to False for all parameters except the output layer | |
| for name, param in self.efficientnet.named_parameters(): | |
| if not name.startswith('classifier'): | |
| param.requires_grad = False""" | |
| # print number of parameters including final layer | |
| trainable_params = sum(p.numel() for p in self.efficientnet.parameters() if p.requires_grad) | |
| total_params = sum(p.numel() for p in self.efficientnet.parameters()) | |
| """print("Efficientnet_b0 with 25 classes initialized") | |
| print(f"Trainable parameters: {trainable_params}") | |
| print(f"Total parameters: {total_params}")""" | |
| def forward(self, x): | |
| return self.efficientnet(x) | |