"""Smoke-test pretrained CSAT, ResNet18 and CSATv2 checkpoints as feature extractors.

Each model is constructed, loaded from its checkpoint on CPU, has its
classification layer replaced with ``nn.Identity()``, and is run on a single
all-zeros image so the shape of the resulting output can be printed.
"""
import torch
from torch import nn

from model.CSAT import CSAT
from model.CSATv2 import CSATv2
from model.ResNet18 import ResNet18

img_size = 224


def extract_features(model, checkpoint_path, head_attr, img_size=224, state_key=None):
    """Load weights into *model*, strip its classifier, and run it on a zero image.

    Args:
        model: Freshly constructed network whose ``load_state_dict`` accepts
            the checkpoint contents.
        checkpoint_path: Path to a checkpoint readable by ``torch.load``.
        head_attr: Name of the classifier sub-module to replace with
            ``nn.Identity()`` (e.g. ``'head'`` or ``'fc'``).
        img_size: Height and width of the dummy square input.
        state_key: If given, the state dict is read from
            ``checkpoint[state_key]`` (for checkpoints that wrap it, e.g.
            ``'state_dict'``); otherwise the loaded object is used directly.

    Returns:
        The model output for a ``(1, 3, img_size, img_size)`` zero tensor.
    """
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    state = torch.load(checkpoint_path, map_location='cpu')
    if state_key is not None:
        state = state[state_key]
    model.load_state_dict(state)

    # Swap out the classifier so the forward pass returns pre-classifier features.
    setattr(model, head_attr, nn.Identity())

    # eval() puts any BatchNorm/Dropout layers into inference mode; no_grad()
    # skips building an autograd graph, which feature extraction does not need.
    model.eval()
    data = torch.zeros((1, 3, img_size, img_size))  # b, c, h, w
    with torch.no_grad():
        return model(data)


def main():
    """Run each pretrained checkpoint once and print its output shape."""
    # CSAT_RCKD.pth.tar can be used instead for pathological image analysis.
    output = extract_features(
        CSAT(img_size=img_size), r'./CSAT_ImageNet.pth.tar', 'head', img_size
    )
    print(output.shape)  # b, c = 1, 176

    output = extract_features(
        ResNet18(), r'./ResNet18_RCKD.pth.tar', 'fc', img_size
    )
    print(output.shape)  # b, c = 1, 512

    # NOTE(review): CSAT replaces `head`, but this replaces `fc`. Confirm that
    # CSATv2's forward actually uses `fc`; if not, the assignment only registers
    # an unused module and the printed output is the classifier logits.
    output = extract_features(
        CSATv2(img_size=img_size),
        r'./CSAT_v2_ImageNet.pth.tar',
        'fc',
        img_size,
        state_key='state_dict',  # this checkpoint wraps the weights in a dict
    )
    print(output.shape)  # expected b, c = 1, 512 per original comment -- TODO confirm


if __name__ == "__main__":
    main()