Update README to include both shipped checkpoints
Browse files
README.md
CHANGED
|
@@ -34,9 +34,10 @@ OmniMorph is a unified framework for 2D/3D multi-modal medical imaging (CT, MRI,
|
|
| 34 |
| `Scripts/` | Auxiliary scripts (registration, evaluation) |
|
| 35 |
| `tests/` | Pytest suite for `OMorpher` and loss functions |
|
| 36 |
| `bash_*.sh`, `*.slurm` | SLURM submission scripts (CUDA + Intel XPU/Dawn) |
|
| 37 |
-
| `Models/all_om_net/000110_all_om_net.pth` | Trained checkpoint
|
|
|
|
| 38 |
|
| 39 |
-
> **Note** Only the final checkpoint
|
| 40 |
|
| 41 |
## Setup
|
| 42 |
|
|
@@ -85,9 +86,14 @@ python OM_reg_flexres.py -C Config/config_om.yaml
|
|
| 85 |
import torch
|
| 86 |
from Diffusion.networks import get_net
|
| 87 |
|
| 88 |
-
# Production network (multi-modal
|
| 89 |
net = get_net("recmulmodmutattnnet")
|
| 90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
net.load_state_dict(state["model"] if "model" in state else state)
|
| 92 |
net.eval()
|
| 93 |
```
|
|
|
|
| 34 |
| `Scripts/` | Auxiliary scripts (registration, evaluation) |
|
| 35 |
| `tests/` | Pytest suite for `OMorpher` and loss functions |
|
| 36 |
| `bash_*.sh`, `*.slurm` | SLURM submission scripts (CUDA + Intel XPU/Dawn) |
|
| 37 |
+
| `Models/all_om_net/000110_all_om_net.pth` | Trained checkpoint — production multi-modal `recmulmodmutattnnet` (epoch 110, ~3.0 GB) |
|
| 38 |
+
| `Models/all_recmulmodmutattnnet/000010_all_recmulmodmutattnnet.pth` | Earlier `recmulmodmutattnnet` run (epoch 10, ~906 MB) |
|
| 39 |
|
| 40 |
+
> **Note** Only the final checkpoint of each training run is shipped — intermediate epochs and the `bert_large_uncased` weights are not bundled. Download `bert-large-uncased` from the official Hugging Face repo if you need the contrastive text encoder.
|
| 41 |
|
| 42 |
## Setup
|
| 43 |
|
|
|
|
| 86 |
import torch
|
| 87 |
from Diffusion.networks import get_net
|
| 88 |
|
| 89 |
+
# Production network (multi-modal recmulmodmutattnnet)
|
| 90 |
net = get_net("recmulmodmutattnnet")
|
| 91 |
+
|
| 92 |
+
# Production checkpoint (epoch 110)
|
| 93 |
+
ckpt_path = "Models/all_om_net/000110_all_om_net.pth"
|
| 94 |
+
# Or earlier run: "Models/all_recmulmodmutattnnet/000010_all_recmulmodmutattnnet.pth"
|
| 95 |
+
|
| 96 |
+
state = torch.load(ckpt_path, map_location="cpu")
|
| 97 |
net.load_state_dict(state["model"] if "model" in state else state)
|
| 98 |
net.eval()
|
| 99 |
```
|