maxmo2009 commited on
Commit
a3690e7
·
verified ·
1 Parent(s): 231f119

Update README to include both shipped checkpoints

Browse files
Files changed (1) hide show
  1. README.md +10 -4
README.md CHANGED
@@ -34,9 +34,10 @@ OmniMorph is a unified framework for 2D/3D multi-modal medical imaging (CT, MRI,
34
  | `Scripts/` | Auxiliary scripts (registration, evaluation) |
35
  | `tests/` | Pytest suite for `OMorpher` and loss functions |
36
  | `bash_*.sh`, `*.slurm` | SLURM submission scripts (CUDA + Intel XPU/Dawn) |
37
- | `Models/all_om_net/000110_all_om_net.pth` | Trained checkpoint (epoch 110, multi-modal `recmulmodmutattnnet`) |
 
38
 
39
- > **Note** Only the final checkpoint (epoch 110) is shipped here. Earlier epochs and the `bert_large_uncased` weights are not bundled download `bert-large-uncased` from the official Hugging Face repo if you need the contrastive text encoder.
40
 
41
  ## Setup
42
 
@@ -85,9 +86,14 @@ python OM_reg_flexres.py -C Config/config_om.yaml
85
  import torch
86
  from Diffusion.networks import get_net
87
 
88
- # Production network (multi-modal recmutattnnet)
89
  net = get_net("recmulmodmutattnnet")
90
- state = torch.load("Models/all_om_net/000110_all_om_net.pth", map_location="cpu")
 
 
 
 
 
91
  net.load_state_dict(state["model"] if "model" in state else state)
92
  net.eval()
93
  ```
 
34
  | `Scripts/` | Auxiliary scripts (registration, evaluation) |
35
  | `tests/` | Pytest suite for `OMorpher` and loss functions |
36
  | `bash_*.sh`, `*.slurm` | SLURM submission scripts (CUDA + Intel XPU/Dawn) |
37
+ | `Models/all_om_net/000110_all_om_net.pth` | Trained checkpoint production multi-modal `recmulmodmutattnnet` (epoch 110, ~3.0 GB) |
38
+ | `Models/all_recmulmodmutattnnet/000010_all_recmulmodmutattnnet.pth` | Earlier `recmulmodmutattnnet` run (epoch 10, ~906 MB) |
39
 
40
+ > **Note** Only the final checkpoint of each training run is shipped intermediate epochs and the `bert_large_uncased` weights are not bundled. Download `bert-large-uncased` from the official Hugging Face repo if you need the contrastive text encoder.
41
 
42
  ## Setup
43
 
 
86
  import torch
87
  from Diffusion.networks import get_net
88
 
89
+ # Production network (multi-modal recmulmodmutattnnet)
90
  net = get_net("recmulmodmutattnnet")
91
+
92
+ # Production checkpoint (epoch 110)
93
+ ckpt_path = "Models/all_om_net/000110_all_om_net.pth"
94
+ # Or earlier run: "Models/all_recmulmodmutattnnet/000010_all_recmulmodmutattnnet.pth"
95
+
96
+ state = torch.load(ckpt_path, map_location="cpu")
97
  net.load_state_dict(state["model"] if "model" in state else state)
98
  net.eval()
99
  ```