import torch
from torch import nn

from model.CSAT import CSAT
from model.CSATv2 import CSATv2
from model.ResNet18 import ResNet18


# --- CSAT: load the checkpoint and print the output shape with `head` removed ---

# Side length of the square dummy input; also passed to the CSAT / CSATv2
# constructors below.
img_size = 224

path = r'./CSAT_ImageNet.pth.tar'
model = CSAT(img_size=img_size)

# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
# This file is handed to load_state_dict() as-is, i.e. it is expected to
# contain the bare state dict.
state = torch.load(path, map_location='cpu')
model.load_state_dict(state)

# Replace the `head` submodule with a pass-through so the forward pass returns
# whatever is fed into it (presumably the classifier input -- confirm against
# model/CSAT.py).
model.head = nn.Identity()

# Inference only: a freshly constructed module is in training mode, so switch
# to eval mode to make any dropout / batch-norm layers use their inference
# behaviour instead of updating running statistics on the dummy batch.
model.eval()

data = torch.zeros((1, 3, img_size, img_size))

# No gradients are needed just to inspect the output shape.
with torch.no_grad():
    output = model(data)

print(output.shape)


# --- ResNet18 (RCKD checkpoint): print the output shape with `fc` removed ---

path = r'./ResNet18_RCKD.pth.tar'
model = ResNet18()

# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
# This file is handed to load_state_dict() as-is, i.e. it is expected to
# contain the bare state dict.
state = torch.load(path, map_location='cpu')
model.load_state_dict(state)

# Replace the `fc` submodule with a pass-through so the forward pass returns
# whatever is fed into it (presumably the classifier input -- confirm against
# model/ResNet18.py).
model.fc = nn.Identity()

# Inference only: a freshly constructed module is in training mode, so switch
# to eval mode to make any dropout / batch-norm layers use their inference
# behaviour instead of updating running statistics on the dummy batch.
model.eval()

data = torch.zeros((1, 3, img_size, img_size))

# No gradients are needed just to inspect the output shape.
with torch.no_grad():
    output = model(data)

print(output.shape)


# --- CSATv2: load the checkpoint and print the output shape with `fc` removed ---

path = r'./CSAT_v2_ImageNet.pth.tar'
model = CSATv2(img_size=img_size)

# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
# Unlike the other two checkpoints, this one is indexed with 'state_dict'
# before being passed to load_state_dict(), i.e. the weights are expected to
# be nested under that key.
state = torch.load(path, map_location='cpu')
model.load_state_dict(state['state_dict'])

# Replace the `fc` submodule with a pass-through.
# NOTE(review): the CSAT section replaces `head`, this one replaces `fc`.
# If CSATv2 has no `fc` attribute, this assignment only registers an unused
# submodule and the printed shape is still the classifier output -- confirm
# the attribute name against model/CSATv2.py.
model.fc = nn.Identity()

# Inference only: a freshly constructed module is in training mode, so switch
# to eval mode to make any dropout / batch-norm layers use their inference
# behaviour instead of updating running statistics on the dummy batch.
model.eval()

data = torch.zeros((1, 3, img_size, img_size))

# No gradients are needed just to inspect the output shape.
with torch.no_grad():
    output = model(data)

print(output.shape)