---
license: apache-2.0
tags:
- pytorch
- unet
- chest-ct
- survival-analysis
- time-to-event
- model-3d
model-index:
- name: UNet-TTE
  results: []
---

# SwinUNETR Checkpoint

This is a PyTorch Lightning `.ckpt` checkpoint for a SwinUNETR model pretrained on chest CT volumes with a time-to-event (TTE) objective.

## Usage

The quickstart below builds the model and loads the checkpoint weights.
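```python
import torch

# `src.networks` ships with the tte-pretraining GitHub repo linked below;
# run this from the repo root (or add the repo to PYTHONPATH).
from src.networks import SwinUNETRForClassification

# Path to the downloaded checkpoint (placeholder: set this yourself).
loadmodel_path = "path/to/checkpoint.ckpt"

# Use the current GPU if one is available, otherwise fall back to CPU.
device = torch.device(
    f"cuda:{torch.cuda.current_device()}" if torch.cuda.is_available() else "cpu"
)

# Backbone hyperparameters; these must match the configuration used in training.
swin_unetr_params = {
    "img_size": (224, 224, 224),  # spatial size of the input volume
    "in_channels": 1,  # single-channel CT intensities
    "out_channels": 2,
    "feature_size": 48,
    "drop_rate": 0.0,
    "attn_drop_rate": 0.0,
    "dropout_path_rate": 0.0,
    "use_checkpoint": True,  # gradient checkpointing to reduce memory use
}

model = SwinUNETRForClassification(
    swin_unetr_params=swin_unetr_params, num_classes=2
).to(device)

checkpoint = torch.load(loadmodel_path, map_location=device)
# Lightning `.ckpt` files nest the weights under "state_dict";
# a plain state dict is used as-is.
state_dict = checkpoint.get("state_dict", checkpoint)
model.load_state_dict(state_dict)
model.eval()  # switch to inference mode
```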
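
If `load_state_dict` raises errors about missing or unexpected keys, a common cause with Lightning checkpoints is that the weights are nested under a `"state_dict"` entry and every key carries the attribute name of the wrapped network as a prefix (for example `model.`). The helper below is a hypothetical sketch, not part of the tte-pretraining repo: the name `extract_state_dict` and the default prefix `"model."` are assumptions, so inspect the checkpoint keys (e.g. `list(checkpoint["state_dict"])[:5]`) to find the actual prefix.

```python
from collections import OrderedDict
from typing import Any, Mapping


def extract_state_dict(
    checkpoint: Mapping[str, Any], prefix: str = "model."
) -> "OrderedDict[str, Any]":
    """Return a plain state dict from a loaded checkpoint.

    Hypothetical helper: unwraps a Lightning-style {"state_dict": ...}
    container if present, then keeps only the keys that start with `prefix`
    and strips that prefix. The default prefix is an assumption; if no key
    carries it, the weights are returned unchanged.
    """
    # Lightning checkpoints nest the weights under "state_dict".
    weights = checkpoint.get("state_dict", checkpoint)
    if not any(key.startswith(prefix) for key in weights):
        return OrderedDict(weights)
    # Drop entries that do not belong to the wrapped network (e.g. loss modules).
    return OrderedDict(
        (key[len(prefix):], value)
        for key, value in weights.items()
        if key.startswith(prefix)
    )


if __name__ == "__main__":
    demo = {"state_dict": {"model.encoder.weight": 1, "model.head.bias": 2, "loss_fn.w": 3}}
    print(dict(extract_state_dict(demo)))  # {'encoder.weight': 1, 'head.bias': 2}
```

With it, the `load_state_dict` call in the quickstart becomes `model.load_state_dict(extract_state_dict(state_dict))`; plain state dicts without the prefix pass through unchanged.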

For detailed instructions, follow the [README in the GitHub repo](https://github.com/som-shahlab/tte-pretraining/tree/main?tab=readme-ov-file#evaluation).