File size: 1,247 Bytes
905cc0d
ec378c3
 
 
905cc0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ec378c3
 
 
 
 
 
 
 
 
 
 
 
905cc0d
ec378c3
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from model import _load_one_model, TowPartModel, BrainEncodingModel
from config_utils import load_from_yaml
import torch

# --- Build the two-part brain-encoding model for a single NSD subject. ---
# NSD subject identifier; also keys into the checkpoint filenames and the
# voxel-index dict loaded below.
subject = 'subj01'
cfg_path = "/workspace/model_packed2/config.yaml"
model_path1 = f"/workspace/model_packed2/ckpts/{subject}_part1.pth"
model_path2 = f"/workspace/model_packed2/ckpts/{subject}_part2.pth"
# model1 is for vertices with high noise ceiling (nsdgeneral)
# model2 is for vertices from the rest of the brain
model1: BrainEncodingModel = _load_one_model(model_path1, subject, cfg_path)
model2: BrainEncodingModel = _load_one_model(model_path2, subject, cfg_path)
# voxel_indices is a list of indices of vertices with high noise ceiling (for model1)
voxel_indices_path = "/workspace/model_packed2/ckpts/part1_voxel_indices.pt"
# NOTE(review): torch.load performs full pickle deserialization by default —
# consider weights_only=True if this checkpoint is not fully trusted.
# The file presumably maps subject id -> index tensor; verify against the
# packing script that produced it.
voxel_indices = torch.load(voxel_indices_path)[subject]
# TowPartModel routes model1's outputs to voxel_indices and model2's outputs
# to the remaining vertices (per the comments above; internals not visible here).
model = TowPartModel(model1, model2, voxel_indices)

# Inference only: move to GPU and disable training-mode layers (dropout, BN updates).
model = model.cuda().eval()


# Dummy input: a single 3-channel 224x224 image drawn from a standard normal.
x = torch.randn((1, 3, 224, 224))
def transform_image(x, means=(0.485, 0.456, 0.406), stds=(0.229, 0.224, 0.225)):
    """Channel-wise normalize an image batch: ``(x - mean) / std``.

    Defaults are the standard ImageNet statistics, matching the original
    hard-coded values.

    Args:
        x: image tensor of shape (N, C, H, W); presumably RGB values in
           [0, 1] — TODO confirm against the model's training preprocessing.
        means: per-channel means (length C).
        stds: per-channel standard deviations (length C).

    Returns:
        Tensor of the same shape and dtype as ``x``, normalized per channel.
    """
    # Build the constants on x's device/dtype: the original CPU float32
    # tensors would raise a device (or dtype) mismatch for CUDA or
    # half-precision inputs. view(1, -1, 1, 1) broadcasts over N, H, W and
    # also generalizes beyond exactly 3 channels.
    mean_t = torch.tensor(means, device=x.device, dtype=x.dtype).view(1, -1, 1, 1)
    std_t = torch.tensor(stds, device=x.device, dtype=x.dtype).view(1, -1, 1, 1)
    return (x - mean_t) / std_t
# Normalize the batch and move it onto the GPU in one chain.
x = transform_image(x).cuda()

# Forward pass without autograd bookkeeping — inference only.
with torch.no_grad():
    out = model(x)
# Flattened whole-brain prediction: one value per cortical vertex.
print(out.shape)
# torch.Size([1, 327684])