File size: 1,100 Bytes
5df9707
 
 
 
 
 
 
f8cea41
5df9707
 
 
 
 
 
 
 
f8cea41
5df9707
 
 
 
 
 
 
 
f8cea41
5df9707
 
 
f8cea41
5df9707
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
from model.ResNet18 import ResNet18
from model.CSAT import CSAT
from model.CSATv2 import CSATv2
from torch import nn

img_size = 224  # input resolution (H = W) used by every demo in this script

# --- CSAT (ImageNet weights) used as a frozen feature extractor ---
path = r'./CSAT_ImageNet.pth.tar' #or CSAT_RCKD.pth.tar <- for pathological image analysis
model = CSAT(img_size=img_size)
# NOTE(review): torch.load unpickles the checkpoint; only load files from a trusted source.
state = torch.load(path, map_location='cpu')
model.load_state_dict(state)
data = torch.zeros((1, 3, img_size, img_size)) #b, c, h, w = 1, 3, 224, 224
# Replace the classification head with an identity so the forward pass returns features.
model.head = nn.Identity()
# Switch to inference mode: modules such as Dropout/BatchNorm (if the model has any)
# otherwise keep their training-time behaviour and give non-deterministic/biased features.
model.eval()
# Feature extraction needs no autograd graph; skip building one.
with torch.no_grad():
    output = model(data)  # expected (b, c) = (1, 176) per the original comment
print(output.shape)

# --- ResNet18 (RCKD weights) used as a frozen feature extractor ---
path = r'./ResNet18_RCKD.pth.tar'
model = ResNet18()
# NOTE(review): torch.load unpickles the checkpoint; only load files from a trusted source.
state = torch.load(path, map_location='cpu')
model.load_state_dict(state)
data = torch.zeros((1, 3, img_size, img_size)) #b, c, h, w = 1, 3, 224, 224
# Replace the final fully-connected classifier with an identity to expose pooled features.
model.fc = nn.Identity()
# Inference mode so BatchNorm/Dropout (if present) use their eval-time behaviour.
model.eval()
# No gradients are needed for feature extraction.
with torch.no_grad():
    output = model(data)  # expected (b, c) = (1, 512) per the original comment
print(output.shape)

# --- CSATv2 (ImageNet weights) used as a frozen feature extractor ---
path = r'./CSAT_v2_ImageNet.pth.tar'
model = CSATv2(img_size=img_size)
# NOTE(review): torch.load unpickles the checkpoint; only load files from a trusted source.
state = torch.load(path, map_location='cpu')
# This checkpoint wraps the weights under a 'state_dict' key, unlike the two above.
model.load_state_dict(state['state_dict'])
data = torch.zeros((1, 3, img_size, img_size))  # b, c, h, w = 1, 3, 224, 224
# NOTE(review): CSAT above exposes its classifier as `head`; confirm CSATv2's classifier
# attribute really is `fc`. If it is not, this assignment only adds an unused module and
# the output below would still be class logits rather than features.
model.fc = nn.Identity()
# Inference mode so BatchNorm/Dropout (if present) use their eval-time behaviour.
model.eval()
# No gradients are needed for feature extraction.
with torch.no_grad():
    output = model(data)  # original comment says (b, c) = (1, 512) — TODO confirm for CSATv2
print(output.shape)