Spaces:
Runtime error
Runtime error
| import timm | |
| import torch.nn as nn | |
| import torch | |
def get_efficientnet(model_name: str, pretrained: bool = True) -> nn.Module:
    """Create a timm model by name.

    Parameters
    ----------
    model_name : str
        Name of an architecture known to ``timm.create_model``
        (e.g. ``'efficientnet_b0'``).
    pretrained : bool, default True
        Whether to load pretrained weights. Defaults to ``True`` so existing
        callers keep the previous behaviour of this helper.

    Returns
    -------
    nn.Module
        The model exactly as returned by ``timm.create_model`` (its original
        classification head is left in place).
    """
    return timm.create_model(model_name, pretrained=pretrained)
class CustomEfficientNet(nn.Module):
    """
    This class defines a custom EfficientNet network.

    A timm backbone whose ``classifier`` layer is replaced by a small MLP head:
    ``Linear(in_features, hidden_size) -> ReLU -> Linear(hidden_size, target_size)``,
    optionally with dropout before each linear layer.

    Parameters
    ----------
    model_name : str
        Name of the timm architecture to build. The backbone must expose a
        ``classifier`` attribute that has ``in_features``; an architecture
        whose head uses a different attribute name raises ``AttributeError``.
    target_size : int
        Number of units for the output layer.
    pretrained : bool
        Determine if pretrained weights are used.
    hidden_size : int
        Width of the hidden layer of the classification head (default 256).
    dropout : float
        Dropout probability applied before each linear layer of the head.
        With the default ``0.0`` no dropout layers are added, so the head's
        layer indices (and therefore ``state_dict`` keys) are identical to the
        plain three-layer head.

    Attributes
    ----------
    model : nn.Module
        EfficientNet model with its classifier replaced by the custom head.
    """

    def __init__(self, model_name: str = 'efficientnet_b0',
                 target_size: int = 4, pretrained: bool = True,
                 hidden_size: int = 256, dropout: float = 0.0):
        super().__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained)
        # Swap the backbone's original classifier for the custom MLP head,
        # sized from the classifier it replaces.
        in_features = self.model.classifier.in_features
        self.model.classifier = self._make_head(
            in_features, hidden_size, target_size, dropout)

    @staticmethod
    def _make_head(in_features: int, hidden_size: int, target_size: int,
                   dropout: float = 0.0) -> nn.Sequential:
        """Build the MLP classification head.

        Dropout layers are only inserted when ``dropout > 0`` so that the
        default head keeps the layer order ``[Linear, ReLU, Linear]``, which
        checkpoints saved with the original head depend on.
        """
        layers = []
        if dropout > 0:
            layers.append(nn.Dropout(dropout))
        layers += [nn.Linear(in_features, hidden_size), nn.ReLU()]
        if dropout > 0:
            layers.append(nn.Dropout(dropout))
        layers.append(nn.Linear(hidden_size, target_size))
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the backbone (including the custom head)."""
        return self.model(x)