Spaces:
Runtime error
Runtime error
| import timm | |
| import torch.nn as nn | |
| import albumentations as A | |
| import torch | |
| import cv2 | |
class CustomNormalization(A.ImageOnlyTransform):
    """Image-only transform that divides every pixel value by 255.

    No mean/std standardisation is applied; the image is only rescaled.
    Masks and other non-image targets are left untouched because this is
    an ``ImageOnlyTransform``.
    """

    def _norm(self, img):
        """Return ``img`` divided by 255."""
        # NOTE(review): if ``img`` is a uint8 NumPy array, this division yields
        # float64. PyTorch models usually expect float32 — confirm the caller
        # casts (e.g. ``.float()``) before inference.
        return img / 255.0

    def apply(self, img, **params):
        """Apply the rescaling to the image target; ``params`` are unused."""
        return self._norm(img)

    def get_transform_init_args_names(self):
        """Return the names of constructor args to record when serialising.

        This transform has no arguments beyond the base ones (such as ``p``).
        The override is required because, in albumentations 1.x, the base
        implementation raises ``NotImplementedError``, which breaks
        ``repr()``, ``print(A.Compose([...]))`` and ``A.save``.
        """
        return ()
def transform_image(image, size):
    """Resize ``image`` to a ``size`` x ``size`` square, then divide its pixels by 255.

    Parameters
    ----------
    image
        Input image, in whatever array format albumentations accepts
        (presumably an HxW or HxWxC NumPy array — confirm against the caller).
    size : int
        Target height and width in pixels.

    Returns
    -------
    The resized, rescaled image taken from the pipeline's ``'image'`` output.
    """
    # Nearest-neighbour resizing is kept exactly as before; presumably it
    # matches the preprocessing used in training — confirm before changing.
    pipeline = A.Compose([
        A.Resize(size, size, interpolation=cv2.INTER_NEAREST),
        CustomNormalization(p=1),
    ])
    return pipeline(image=image)['image']
class CustomEfficientNet(nn.Module):
    """
    timm EfficientNet backbone with a two-layer classification head.

    Parameters
    ----------
    model_name : str
        Model identifier passed to ``timm.create_model``.
    target_size : int
        Number of units for the output layer.
    pretrained : bool
        Whether ``timm.create_model`` loads pretrained weights.

    Attributes
    ----------
    model : nn.Module
        The timm model, with its ``classifier`` replaced by
        ``Linear(in_features, 256) -> ReLU -> Linear(256, target_size)``.
    """

    def __init__(self, model_name: str = 'efficientnet_b0',
                 target_size: int = 4, pretrained: bool = True):
        super().__init__()
        backbone = timm.create_model(model_name, pretrained=pretrained)
        n_features = backbone.classifier.in_features
        # The layer order determines the state_dict keys
        # (model.classifier.0 / model.classifier.2). Keep it stable so any
        # checkpoint saved with this layout still loads.
        backbone.classifier = nn.Sequential(
            nn.Linear(n_features, 256),
            nn.ReLU(),
            nn.Linear(256, target_size),
        )
        self.model = backbone

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the wrapped model and return its raw scores.

        The head ends in a plain ``Linear``, so no softmax is applied.
        """
        return self.model(x)
class CustomViT(nn.Module):
    """
    timm Vision Transformer backbone with an MLP classification head.

    Parameters
    ----------
    model_name : str
        Model identifier passed to ``timm.create_model``.
    target_size : int
        Number of units for the output layer.
    pretrained : bool
        Whether ``timm.create_model`` loads pretrained weights.

    Attributes
    ----------
    model : nn.Module
        The timm model, with its ``head`` replaced by
        ``Linear(in_features, 256) -> ReLU -> Dropout(0.5) -> Linear(256, target_size)``.
    """

    def __init__(self, model_name: str = 'vit_base_patch16_224',
                 target_size: int = 4, pretrained: bool = True):
        super().__init__()
        # timm builds a ``target_size``-way head here. It is replaced below,
        # and only its ``in_features`` is read.
        backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=target_size,
        )
        n_features = backbone.head.in_features
        # The layer order determines the state_dict keys
        # (model.head.0 / model.head.3). Keep it stable so any checkpoint
        # saved with this layout still loads.
        backbone.head = nn.Sequential(
            nn.Linear(n_features, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, target_size),
        )
        self.model = backbone

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the wrapped model and return its raw scores.

        The head ends in a plain ``Linear``, so no softmax is applied.
        """
        return self.model(x)