Spaces:
Build error
Build error
| import timm | |
| import torch.nn as nn | |
class Build_Custom_Model(nn.Module):
    """Single-channel timm backbone whose final classifier is replaced.

    The backbone is built with ``timm.create_model(..., in_chans=1)`` and its
    final ``nn.Linear`` classifier is swapped for a freshly initialised
    ``nn.Linear(n_features, target_size)``.

    Args:
        model_name: timm model identifier passed to ``timm.create_model``.
        target_size: Number of outputs of the replacement head.
        pretrained: Forwarded to ``timm.create_model``.

    Attributes:
        model: The wrapped timm model.
        n_features: Input width of the replaced classifier. It is ``None``
            only on the fallback path when the new classifier exposes no
            ``in_features``.
    """

    # Dotted attribute path of the final nn.Linear classifier for each model
    # this wrapper handles explicitly. Other names fall back to timm's
    # generic ``reset_classifier`` (see ``_replace_head``).
    _HEAD_PATHS = {
        "vit_base_patch16_224": "head",
        "swin_base_patch4_window7_224": "head",
        "resnet34d": "fc",
        "resnet18d": "fc",
        "tf_efficientnet_b7_ns": "classifier",
        "tf_efficientnet_b0_ns": "classifier",
        "tf_efficientnet_lite0": "classifier",
        "mobilenetv2_050": "classifier",
        "eca_nfnet_l0": "head.fc",
    }

    def __init__(self, model_name, target_size, pretrained=False):
        super().__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained, in_chans=1)
        self.n_features = self._replace_head(self.model, model_name, target_size)

    @classmethod
    def _replace_head(cls, model, model_name, target_size):
        """Replace ``model``'s classifier with ``nn.Linear(in, target_size)``.

        Args:
            model: The backbone whose head is modified in place.
            model_name: Key into ``_HEAD_PATHS`` selecting the head attribute.
            target_size: Output width of the new linear layer.

        Returns:
            The input width of the replaced classifier (``None`` if the
            fallback head has no ``in_features``).
        """
        path = cls._HEAD_PATHS.get(model_name)
        if path is None:
            # Previously, unlisted names silently kept timm's default head, so
            # target_size was ignored. Use timm's generic head reset instead.
            model.reset_classifier(target_size)
            return getattr(model.get_classifier(), "in_features", None)

        # Walk to the module that owns the classifier (e.g. model.head for
        # "head.fc"), then swap the leaf attribute.
        *parents, leaf = path.split(".")
        owner = model
        for attr in parents:
            owner = getattr(owner, attr)
        n_features = getattr(owner, leaf).in_features
        setattr(owner, leaf, nn.Linear(n_features, target_size))
        return n_features

    def forward(self, x):
        """Run ``x`` through the wrapped model and return its output."""
        return self.model(x)
def reshape_transform(tensor, height=7, width=7):
    """Fold a token sequence into a channels-first spatial grid.

    ``tensor`` is reshaped to ``(size(0), height, width, size(2))`` and the
    last axis is moved to position 1, giving ``(size(0), size(2), height,
    width)``. This assumes ``size(1) == height * width`` (NOTE(review): e.g.
    no extra class token) — confirm against the model this is used with.
    """
    batch = tensor.size(0)
    channels = tensor.size(2)
    grid = tensor.reshape(batch, height, width, channels)
    # (B, H, W, C) -> (B, C, H, W): channels first, like CNN feature maps.
    return grid.permute(0, 3, 1, 2)