File size: 1,100 Bytes
5df9707 f8cea41 5df9707 f8cea41 5df9707 f8cea41 5df9707 f8cea41 5df9707 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
import torch
from model.ResNet18 import ResNet18
from model.CSAT import CSAT
from model.CSATv2 import CSATv2
from torch import nn
img_size = 224


def _extract_features(model, path, head_attr, state_key=None):
    """Load a checkpoint into *model*, drop its classifier, and run a dummy image.

    Args:
        model: freshly constructed network whose parameters match the checkpoint.
        path: checkpoint file passed to ``torch.load`` (mapped onto the CPU).
        head_attr: name of the classifier submodule to replace with
            ``nn.Identity`` so the forward pass presumably yields features
            rather than class logits.
        state_key: if given, the checkpoint is a dict and the weights are read
            from this key (e.g. ``'state_dict'``); otherwise the loaded object
            itself is used as the state dict.

    Returns:
        The model output for an all-zeros input of shape
        (1, 3, img_size, img_size).
    """
    # NOTE: torch.load unpickles arbitrary objects -- only load trusted
    # checkpoints (on torch >= 1.13 consider weights_only=True).
    state = torch.load(path, map_location='cpu')
    if state_key is not None:
        state = state[state_key]
    # Load BEFORE swapping the head: the checkpoint still contains head weights.
    model.load_state_dict(state)
    setattr(model, head_attr, nn.Identity())
    # eval() puts any BatchNorm/Dropout layers into inference mode; no_grad()
    # avoids building an autograd graph that is never used.
    model.eval()
    data = torch.zeros((1, 3, img_size, img_size))  # b, c, h, w
    with torch.no_grad():
        return model(data)


# CSAT (ImageNet weights; or CSAT_RCKD.pth.tar <- for pathological image analysis)
output = _extract_features(CSAT(img_size=img_size), r'./CSAT_ImageNet.pth.tar', 'head')
print(output.shape)  # b, c = 1, 176

# ResNet18 (RCKD checkpoint)
output = _extract_features(ResNet18(), r'./ResNet18_RCKD.pth.tar', 'fc')
print(output.shape)  # b, c = 1, 512

# CSATv2 (ImageNet weights); this checkpoint stores the weights under 'state_dict'.
# NOTE(review): this replaces `fc`, whereas CSAT's classifier is `head` -- confirm
# CSATv2's classifier attribute name and the resulting feature width.
output = _extract_features(
    CSATv2(img_size=img_size),
    r'./CSAT_v2_ImageNet.pth.tar',
    'fc',
    state_key='state_dict',
)
print(output.shape)