Spaces:
Runtime error
Runtime error
| import torch | |
| import os | |
| import torch.nn as nn | |
| import torchvision.datasets as datasets | |
| import torchvision.transforms as transforms | |
| import torchvision.models as models | |
# Download the checkpoint from Google Drive (by file id) into the working
# directory. The argument list is run without a shell, so the "?" in the URL
# needs no quoting, and check=True makes a failed download raise here rather
# than surfacing later as a missing file. The URL form is used because the
# "--id" flag is deprecated in newer gdown releases.
import subprocess

subprocess.run(
    ["gdown", "https://drive.google.com/uc?id=1qO2OLR7skDibo1LaMKD3CiOl_jaCTZ0h"],
    check=True,
)
class JHARMNet(nn.Module):
    """Wrap a backbone and replace its ``fc`` head with a new linear classifier.

    Args:
        pretrained_model: Backbone module; its ``fc`` attribute is replaced
            in place by a new ``nn.Linear``.
        num_classes: Number of output units of the new head.
    """

    def __init__(self, pretrained_model, num_classes):
        super().__init__()
        self.premodel = pretrained_model
        # Size the new head from the layer it replaces so backbones whose
        # final layer is not 2048 wide also work; 2048 (the previously
        # hard-coded width) remains the fallback when no such layer exists.
        old_head = getattr(pretrained_model, "fc", None)
        in_features = getattr(old_head, "in_features", 2048)
        self.premodel.fc = nn.Linear(in_features, num_classes)
        nn.init.xavier_uniform_(self.premodel.fc.weight)

    def forward(self, x):
        """Return the wrapped backbone's output for ``x``."""
        return self.premodel(x)
class HiddenLayer(nn.Module):
    """Backbone followed by an MLP head: 1000 -> 512 -> 512 -> 256 -> 10.

    Args:
        pretrained_model: Module applied first; its output is fed to the
            head, whose first layer takes 1000 input features.
    """

    def __init__(self, pretrained_model):
        super().__init__()
        self.premodel = pretrained_model

        # Each hidden Linear is followed by a LeakyReLU; the final projection
        # to 10 outputs has no activation. Layers are created in the same
        # order as their positions in the Sequential container.
        widths = (1000, 512, 512, 256)
        stages = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stages.append(nn.Linear(fan_in, fan_out))
            stages.append(nn.LeakyReLU())
        stages.append(nn.Linear(256, 10))
        self.new_layer = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the backbone, then through the MLP head."""
        return self.new_layer(self.premodel(x))
# Smoke test: rebuild the network, load the checkpoint and push a dummy batch
# through it, printing the input and output shapes.
#
# The backbone is created without ImageNet weights: the strict
# load_state_dict call below overwrites every parameter and buffer, so
# downloading pretrained weights first would be wasted work.
resnet = models.resnet50()
x = torch.randn((2, 3, 32, 32))
print(x.shape)
model_check = HiddenLayer(resnet)
# map_location="cpu" lets a checkpoint saved from a GPU run load on a
# CPU-only host.
model_check.load_state_dict(torch.load("CIFAR_end_hll.pt", map_location="cpu"))
model_check.eval()
# Inference only: skip autograd bookkeeping, and call the module itself
# (not .forward) so any registered hooks run.
with torch.no_grad():
    print(model_check(x).shape)