Spaces:
Sleeping
Sleeping
| import torch.nn as nn | |
| from torchvision.models import efficientnet_b2, EfficientNet_B2_Weights | |
class EffnetB2(nn.Module):
    """EfficientNet-B2 backbone with a replacement linear classification head.

    Builds ``torchvision.models.efficientnet_b2`` with ``weights``, optionally
    freezes all of its parameters, and swaps its classifier for a single
    ``nn.Linear`` that outputs ``num_classes`` values.

    Args:
        num_classes: Number of output units of the new head.
        weights: Forwarded unchanged to ``efficientnet_b2``. Accepts an
            ``EfficientNet_B2_Weights`` member, its name as a string (the
            default ``"DEFAULT"`` resolves to
            ``EfficientNet_B2_Weights.DEFAULT``), or ``None`` for random
            initialisation. Passing ``None`` is useful when a fine-tuned
            state dict is loaded immediately after construction, since it
            skips downloading/loading the pretrained weights.
        freeze_backbone: If True (default), set ``requires_grad = False`` on
            every parameter of the model built from ``weights``, so only the
            newly created head receives gradients.
    """

    def __init__(
        self,
        num_classes: int = 3,
        *,
        weights: "EfficientNet_B2_Weights | str | None" = "DEFAULT",
        freeze_backbone: bool = True,
    ) -> None:
        super().__init__()
        self.model = efficientnet_b2(weights=weights)

        if freeze_backbone:
            # Freeze before replacing the head: the new nn.Linear below is
            # created afterwards, so its parameters keep requires_grad=True.
            for param in self.model.parameters():
                param.requires_grad = False

        # Position 1 of the stock classifier is the layer exposing the
        # feature width; the new head must accept the same input size.
        in_features = self.model.classifier[1].in_features

        # The replacement head has no Dropout. The Linear stays wrapped in
        # nn.Sequential at index 0 so state-dict keys remain
        # ``model.classifier.0.weight`` / ``.bias`` for existing checkpoints;
        # inserting layers before it would break loading them.
        self.model.classifier = nn.Sequential(
            nn.Linear(in_features=in_features, out_features=num_classes)
        )

    def forward(self, x):
        """Return the output of the wrapped EfficientNet-B2 (new head) for ``x``."""
        return self.model(x)