import os
import zipfile

import torch

from utils.network import UNet


def _load_unet(model_dir, name, in_channels, out_channels, device):
    """Build one U-Net, load its weights from ``<model_dir>/model/<name>/<name>.pth``, and ready it for inference."""
    weights_path = os.path.join(model_dir, "model", name, name + ".pth")
    net = UNet(in_channels, out_channels)
    # Deserialize onto the CPU first: without map_location, torch.load tries to
    # restore tensors to the device they were saved from, which fails when the
    # checkpoint was written on a GPU and this host has none. The model is
    # moved to the requested device right after.
    state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
    net.load_state_dict(state_dict)
    net.to(device)
    net.eval()
    return net


def load_model(model_dir, device):
    """
    Load and initialize the pretrained neural network models required for the OpenMAP-T1 pipeline.

    Four U-Net-based models are read from ``<model_dir>/model/<Name>/<Name>.pth``.
    Each model is moved to the target device (CPU, CUDA, or MPS) and set to
    evaluation mode.

    Models loaded:
        1. **CNet (Cropping Network)** — Performs face cropping and brain localization.
        2. **SSNet (Skull Stripping Network)** — Removes non-brain tissues from MRI scans.
        3. **PNet (Parcellation Network)** — Predicts fine-grained anatomical labels across 142 regions.
        4. **HNet (Hemisphere Network)** — Segments the brain into hemispheric masks (left/right/other).

    Args:
        model_dir (str): Directory that contains the extracted ``model`` folder
            with one sub-folder per network (``CNet``, ``SSNet``, ``PNet``, ``HNet``).
        device (torch.device): Target device on which to place the models
            (e.g., `torch.device('cuda')`).

    Returns:
        tuple: A tuple containing four initialized and evaluation-ready models:
            (cnet, ssnet, pnet, hnet).

    Raises:
        FileNotFoundError: If a checkpoint file is missing (raised by ``torch.load``).
        RuntimeError: If a checkpoint does not match the network's parameters
            (raised by ``load_state_dict``).
    """
    # NOTE: the weights are expected to be already extracted under
    # <model_dir>/model; this function does not unpack model.zip itself.

    # CNet: 3 input channels (neighboring slices) -> 1-channel binary mask.
    cnet = _load_unet(model_dir, "CNet", 3, 1, device)

    # SSNet: 3 input channels (neighboring slices) -> 1-channel brain mask.
    ssnet = _load_unet(model_dir, "SSNet", 3, 1, device)

    # PNet: 4 input channels -> 142 anatomical region channels.
    pnet = _load_unet(model_dir, "PNet", 4, 142, device)

    # HNet: 3 input channels -> 3-class hemisphere mask.
    hnet = _load_unet(model_dir, "HNet", 3, 3, device)

    return cnet, ssnet, pnet, hnet