File size: 1,247 Bytes
905cc0d ec378c3 905cc0d ec378c3 905cc0d ec378c3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
from model import _load_one_model, TowPartModel, BrainEncodingModel
from config_utils import load_from_yaml
import torch
# --- Model setup -------------------------------------------------------------
# Subject whose checkpoints are loaded; the same id is the key into the
# voxel-index table loaded further down.
subject = 'subj01'

# Shared config plus the two per-subject checkpoint halves.
cfg_path = "/workspace/model_packed2/config.yaml"
model_path1 = f"/workspace/model_packed2/ckpts/{subject}_part1.pth"
model_path2 = f"/workspace/model_packed2/ckpts/{subject}_part2.pth"

# Part 1 predicts the vertices with a high noise ceiling (nsdgeneral);
# part 2 predicts the vertices from the rest of the brain.
model1: BrainEncodingModel
model2: BrainEncodingModel
model1, model2 = (
    _load_one_model(ckpt, subject, cfg_path) for ckpt in (model_path1, model_path2)
)

# Per-subject table of the high-noise-ceiling vertex indices routed to model1.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
voxel_indices_path = "/workspace/model_packed2/ckpts/part1_voxel_indices.pt"
voxel_indices = torch.load(voxel_indices_path)[subject]

# Stitch both halves into a single model, move it to the GPU and put it in
# eval mode.
model = TowPartModel(model1, model2, voxel_indices).cuda().eval()

# Random placeholder batch of one 3x224x224 image in (N, C, H, W) layout
# (presumably RGB). NOTE(review): torch.randn samples N(0, 1), whereas real
# pixels would presumably be scaled to [0, 1] before normalization.
x = torch.randn(1, 3, 224, 224)
def transform_image(x, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Normalize an image tensor channel-wise as ``(x - mean) / std``.

    The defaults are the widely used ImageNet RGB statistics (the values
    torchvision documents for its pretrained models).

    Args:
        x: Image tensor whose channel axis is dim 1, e.g. ``(N, C, H, W)``.
            Pixel values are presumably already scaled to [0, 1] -- TODO
            confirm against the encoder's training pipeline. A 3-D
            ``(C, H, W)`` tensor is also accepted; broadcasting then returns
            a ``(1, C, H, W)`` result.
        mean: Per-channel means, one entry per channel of ``x``.
        std: Per-channel standard deviations, same length as ``mean``.

    Returns:
        The normalized tensor, on the same device as ``x``. Its dtype follows
        PyTorch type promotion between ``x`` and the default float dtype, so
        integer inputs (e.g. uint8) come back as floating point.

    Raises:
        ValueError: If ``mean`` and ``std`` have different lengths.
    """
    if len(mean) != len(std):
        raise ValueError(
            f"mean and std must have the same length, got {len(mean)} and {len(std)}"
        )
    # Create the statistics on x's device: CPU-only constants would raise a
    # device-mismatch error when x is a CUDA tensor. Shape (1, C, 1, 1)
    # broadcasts over the batch and spatial dimensions.
    mean_t = torch.tensor(mean, device=x.device).view(1, -1, 1, 1)
    std_t = torch.tensor(std, device=x.device).view(1, -1, 1, 1)
    return (x - mean_t) / std_t
# --- Inference ---------------------------------------------------------------
# Normalize the batch, then move it to the GPU.
x = transform_image(x).cuda()

# Forward pass only, so skip building the autograd graph.
with torch.no_grad():
    out = model(x)

print(out.shape)
# Expected: torch.Size([1, 327684]) -- presumably one predicted value per
# vertex, with the model1 and model2 outputs combined; TODO confirm.