Spaces:
Sleeping
Sleeping
File size: 2,987 Bytes
03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc c6679f4 c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 03642cc c9d94d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import os
import zipfile
import torch
from utils.network import UNet
def _load_unet(model_dir, name, in_channels, out_channels, device):
    """Build a UNet, load ``<model_dir>/model/<name>/<name>.pth`` into it, and prepare it for inference.

    Args:
        model_dir (str): Root directory that contains the ``model`` sub-directory.
        name (str): Network name; used as both the sub-directory and the file stem.
        in_channels (int): Number of input channels passed to ``UNet``.
        out_channels (int): Number of output channels passed to ``UNet``.
        device (torch.device | str): Device the model is moved to.

    Returns:
        UNet: The model with loaded weights, on ``device``, in evaluation mode.

    Raises:
        FileNotFoundError: If the weights file does not exist.
        RuntimeError: If the checkpoint keys/shapes do not match the UNet
            (``load_state_dict`` is called with its default ``strict=True``).
    """
    weights_path = os.path.join(model_dir, "model", name, f"{name}.pth")
    net = UNet(in_channels, out_channels)
    # Deserialize onto the CPU first: without map_location, a checkpoint saved
    # from a CUDA device cannot be loaded on a CPU-only or MPS machine. The
    # model is moved to the requested device afterwards.
    state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
    net.load_state_dict(state_dict)
    net.to(device)
    net.eval()
    return net


def load_model(model_dir, device):
    """
    Load and initialize the pretrained neural network models required for the OpenMAP-T1 pipeline.

    Four U-Net-based models are read from ``<model_dir>/model/<Name>/<Name>.pth``.
    Each one is moved to the target device (e.g. CPU, CUDA, or MPS) and set to
    evaluation mode.

    Models loaded (in this order):
        1. **CNet (Cropping Network)** - UNet(3 -> 1); used for cropping / brain localization.
        2. **SSNet (Skull Stripping Network)** - UNet(3 -> 1); produces a brain mask.
        3. **PNet (Parcellation Network)** - UNet(4 -> 142); predicts 142 anatomical labels.
        4. **HNet (Hemisphere Network)** - UNet(3 -> 3); produces a 3-class hemisphere mask.

    Args:
        model_dir (str): Directory containing the ``model`` sub-directory with
            ``CNet/CNet.pth``, ``SSNet/SSNet.pth``, ``PNet/PNet.pth`` and ``HNet/HNet.pth``.
        device (torch.device | str): Target device on which to place the models
            (e.g. ``torch.device('cuda')``).

    Returns:
        tuple: ``(cnet, ssnet, pnet, hnet)``, each loaded, on ``device``, and in eval mode.

    Raises:
        FileNotFoundError: If any of the four weights files is missing.
        RuntimeError: If a checkpoint does not match its UNet architecture.
    """
    # Channel counts must match the architectures the checkpoints were trained with.
    # CNet / SSNet: 3 input channels (the original comments say neighboring slices), 1 output mask.
    cnet = _load_unet(model_dir, "CNet", 3, 1, device)
    ssnet = _load_unet(model_dir, "SSNet", 3, 1, device)
    # PNet: 4 input channels -> 142 anatomical label channels.
    pnet = _load_unet(model_dir, "PNet", 4, 142, device)
    # HNet: 3 input channels -> 3 output classes (hemisphere mask).
    hnet = _load_unet(model_dir, "HNet", 3, 3, device)
    return cnet, ssnet, pnet, hnet
|