# CSATv2 / example.py
# Author: sosigikiller; commit f8cea41 ("change_folder"); 1.1 kB.
import torch
from model.ResNet18 import ResNet18
from model.CSAT import CSAT
from model.CSATv2 import CSATv2
from torch import nn
img_size = 224


def extract_features(model, path, head_attr, state_key=None, img_size=img_size):
    """Load a checkpoint into *model*, strip its classifier and embed a dummy image.

    Args:
        model: An ``nn.Module`` whose classification layer is the submodule
            named *head_attr*.
        path: Path of the checkpoint file passed to ``torch.load``.
        head_attr: Name of the submodule replaced by ``nn.Identity`` so the
            forward pass returns the pre-classifier features.
        state_key: If given, the weights are read from ``checkpoint[state_key]``
            (e.g. ``'state_dict'``) instead of from the checkpoint itself.
        img_size: Height and width of the all-zeros input image.

    Returns:
        The model output for a ``(1, 3, img_size, img_size)`` zero tensor,
        computed in eval mode without autograd.

    Raises:
        AttributeError: If *model* has no attribute *head_attr*. Assigning it
            anyway would silently register an unused module, and the printed
            "features" would really be classifier outputs.
    """
    if not hasattr(model, head_attr):
        raise AttributeError(
            f"{type(model).__name__} has no attribute {head_attr!r} to replace"
        )
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    # weights_only=True is not forced because older torch releases lack it and
    # it may reject checkpoints that bundle non-tensor metadata.
    state = torch.load(path, map_location='cpu')
    if state_key is not None:
        state = state[state_key]
    model.load_state_dict(state)
    setattr(model, head_attr, nn.Identity())
    # Inference mode: any BatchNorm layers use (and do not update) their
    # running statistics, and any dropout layers are disabled.
    model.eval()
    data = torch.zeros((1, 3, img_size, img_size))  # b, c, h, w
    with torch.no_grad():
        return model(data)


def main():
    """Print the pre-classifier feature shapes of the three released backbones."""
    # CSAT: use CSAT_RCKD.pth.tar instead for pathological image analysis.
    output = extract_features(
        CSAT(img_size=img_size), r'./CSAT_ImageNet.pth.tar', 'head')
    print(output.shape)  # b, c = 1, 176

    output = extract_features(ResNet18(), r'./ResNet18_RCKD.pth.tar', 'fc')
    print(output.shape)  # b, c = 1, 512

    # This checkpoint stores its weights under the 'state_dict' key.
    # NOTE(review): CSAT names its classifier `head`; confirm CSATv2 really
    # names it `fc` (extract_features raises if it does not).
    output = extract_features(
        CSATv2(img_size=img_size), r'./CSAT_v2_ImageNet.pth.tar', 'fc',
        state_key='state_dict')
    print(output.shape)  # b, c = 1, 512 per the original comment -- TODO confirm


if __name__ == "__main__":
    main()