"""Load the pretrained ViT checkpoint onto the CPU and print it."""

import os
from typing import Any, Union

import torch

# Default checkpoint location, relative to the current working directory.
CHECKPOINT_PATH = "models/pretrained_vit.pth"


def load_checkpoint(
    path: Union[str, os.PathLike] = CHECKPOINT_PATH,
    map_location: Any = "cpu",
) -> Any:
    """Deserialize a ``torch.save`` checkpoint from *path*.

    Args:
        path: File previously written by ``torch.save``.
        map_location: Passed straight to ``torch.load``. With the default
            ``"cpu"``, tensors saved from a GPU are remapped to the CPU, so
            the file loads on machines without CUDA.

    Returns:
        The object that was saved. This is presumably a state dict for a
        pretrained model; confirm against whoever produced the file.

    Raises:
        FileNotFoundError: If *path* does not exist.
    """
    # NOTE(review): torch.load unpickles the file, and unpickling can execute
    # arbitrary code. Only load checkpoints from trusted sources. If the file
    # is known to hold only tensors/state dicts, pass weights_only=True to
    # restrict unpickling (recent PyTorch releases default to this).
    return torch.load(path, map_location=map_location)


def main() -> None:
    """Load the default checkpoint and print its contents."""
    checkpoint = load_checkpoint()
    print(checkpoint)


if __name__ == "__main__":
    main()