| | import os.path |
| | import os |
| |
|
| | import monai.networks.nets as nets |
| | import torch |
| |
|
| | from huggingface_hub import hf_hub_download |
| |
|
| | from constants import ROOT_DIR, MODEL_FILENAME, HF_MODEL_REPO_NAME |
| |
|
def load_model():
    """
    Load the pretrained 3D DenseNet121 model in evaluation mode.

    The weights file ``MODEL_FILENAME`` is looked up under ``ROOT_DIR/model``.
    If it is not present, it is first downloaded from the Hugging Face Hub
    repository ``HF_MODEL_REPO_NAME`` into that directory.

    Returns:
        A ``monai.networks.nets.DenseNet121`` built with ``spatial_dims=3``,
        ``in_channels=1``, ``out_channels=3``. The checkpoint weights are
        loaded and ``eval()`` has been called. The model is not moved to a
        GPU here; callers that want CUDA inference must call ``.to(...)``
        themselves.
    """
    model_dir = os.path.join(ROOT_DIR, "model")
    model_path = os.path.join(model_dir, MODEL_FILENAME)

    if not os.path.exists(model_path):
        # hf_hub_download returns the path of the file it wrote. Use it rather
        # than assuming where the file landed inside ``model_dir``.
        model_path = hf_hub_download(
            HF_MODEL_REPO_NAME, MODEL_FILENAME, local_dir=model_dir
        )

    model = nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=3)

    # The freshly built model lives on the CPU (torch's default device), and
    # load_state_dict copies checkpoint values into the model's existing
    # parameters. Mapping the checkpoint to CPU therefore yields the same
    # model with or without CUDA. It also avoids a throwaway GPU copy, and a
    # failure when the checkpoint was saved on a GPU index absent here.
    # NOTE(review): torch.load unpickles the file, which may have been
    # downloaded from the network. If the installed torch supports it
    # (>= 1.13), consider passing weights_only=True to restrict loading to
    # tensors and plain containers.
    checkpoint = torch.load(model_path, map_location="cpu")
    model.load_state_dict(checkpoint)
    model.eval()

    return model