File size: 2,987 Bytes
03642cc
 
 
 
 
 
 
 
 
 
c9d94d6
03642cc
c9d94d6
 
 
 
 
 
 
 
 
 
 
 
03642cc
 
c9d94d6
 
 
03642cc
c6679f4
 
 
c9d94d6
 
 
 
 
 
 
03642cc
 
 
c9d94d6
 
 
 
 
 
03642cc
 
 
c9d94d6
 
 
 
 
 
 
 
03642cc
c9d94d6
 
 
 
 
 
 
 
03642cc
c9d94d6
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
import zipfile

import torch

from utils.network import UNet


def _load_unet(model_dir, name, in_channels, out_channels, device):
    """Build one UNet, load its pretrained weights, and prepare it for inference.

    The checkpoint is read from ``<model_dir>/model/<name>/<name>.pth``.

    Args:
        model_dir: Directory that contains the extracted ``model`` folder.
        name: Sub-network name; used for both the sub-folder and the file stem.
        in_channels: First positional argument passed to ``UNet``.
        out_channels: Second positional argument passed to ``UNet``.
        device: Device the network is moved to after the weights are loaded.

    Returns:
        The ``UNet`` instance on ``device`` in evaluation mode.
    """
    weights_path = os.path.join(model_dir, "model", name, f"{name}.pth")
    net = UNet(in_channels, out_channels)
    # Deserialize onto the CPU first: a checkpoint saved from a CUDA device
    # would otherwise fail to load on a host without CUDA. The module is moved
    # to the requested device right after, so the end result is unchanged.
    state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
    net.load_state_dict(state_dict)
    net.to(device)
    net.eval()
    return net


def load_model(model_dir, device):
    """
    Load and initialize the pretrained neural network models required for the OpenMAP-T1 pipeline.

    This function loads four U-Net–based models from the specified pretrained model directory.
    Each model is moved to the target device (CPU, CUDA, or MPS) and set to evaluation mode.

    The weights must already be extracted on disk; this function does not unpack
    any archive. Expected layout::

        <model_dir>/model/CNet/CNet.pth
        <model_dir>/model/SSNet/SSNet.pth
        <model_dir>/model/PNet/PNet.pth
        <model_dir>/model/HNet/HNet.pth

    Models loaded:
        1. **CNet (Cropping Network)** — Performs face cropping and brain localization.
        2. **SSNet (Skull Stripping Network)** — Removes non-brain tissues from MRI scans.
        3. **PNet (Parcellation Network)** — Predicts fine-grained anatomical labels across 142 regions.
        4. **HNet (Hemisphere Network)** — Segments the brain into hemispheric masks (left/right/other).

    Args:
        model_dir (str): Path to the directory that contains the extracted ``model`` folder.
        device (torch.device): Target device on which to load models (e.g., `torch.device('cuda')`).

    Returns:
        tuple:
            A tuple containing four initialized and evaluation-ready models:
            (cnet, ssnet, pnet, hnet).

    Raises:
        FileNotFoundError: If one of the checkpoint files is missing.
        RuntimeError: If a checkpoint does not match the corresponding UNet architecture.
    """
    # Kept so the existing console output (the resolved CNet checkpoint path) is unchanged.
    print(os.path.join(model_dir, "model", "CNet", "CNet.pth"))

    # CNet (Cropping Network): UNet(3, 1) — 3 input channels, 1-channel binary mask.
    cnet = _load_unet(model_dir, "CNet", 3, 1, device)

    # SSNet (Skull Stripping Network): UNet(3, 1) — 3 input channels, 1-channel brain mask.
    ssnet = _load_unet(model_dir, "SSNet", 3, 1, device)

    # PNet (Parcellation Network): UNet(4, 142) — 4 input channels, 142 anatomical regions.
    pnet = _load_unet(model_dir, "PNet", 4, 142, device)

    # HNet (Hemisphere Network): UNet(3, 3) — 3 input channels, 3-class hemisphere mask.
    hnet = _load_unet(model_dir, "HNet", 3, 3, device)

    # Return all loaded, device-initialized, and evaluation-ready models
    return cnet, ssnet, pnet, hnet