---
license: mit
pipeline_tag: image-to-image
tags:
- pytorch
- medical
- image-generation
- conditional-image-generation
---

# EHRXDiff
Model card for our paper: [Towards Predicting Temporal Changes in a Patient's Chest X-ray Images based on Electronic Health Records](https://arxiv.org/abs/2409.07012). 
We provide two versions of the **EHRXDiff** model:
* **EHRXDiff** – trained without the null-based augmentation technique.
* **EHRXDiff<sub>w_null</sub>** – trained with the null-based augmentation technique.

This card describes the **EHRXDiff** model.
For implementation details, please refer to the [EHRXDiff repository](https://github.com/dek924/EHRXDiff).


## Installation
First, clone the repository and install the required packages:
```
git clone https://github.com/dek924/EHRXDiff.git
cd EHRXDiff

pip install "pip<24.1"
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
pip install -r requirements.txt
```

## Loading the model
You can load the model directly in Python:
```python
from cheff.ldm.models.diffusion.ddpm_tab import EHRXDiff

model = EHRXDiff.from_pretrained("dek924/ehrxdiff")
model.eval()
```
Alternatively, you can download the weights via the Hugging Face Hub:

```python
from huggingface_hub import hf_hub_download

wt_path = hf_hub_download("dek924/ehrxdiff", "pytorch_model.bin")
```
and then run the evaluation script included in our GitHub repository (`scripts/eval.py`), where `CHECKPOINT_PATH` is the directory containing the downloaded `pytorch_model.bin`:
```
# IMG_META_DIR: directory containing metadata for MIMIC-CXR-JPG
# IMG_ROOT_DIR: directory containing preprocessed images
# TAB_ROOT_DIR: directory containing tabular data
python scripts/eval.py \
    --sdm_path=${CHECKPOINT_PATH}/pytorch_model.bin \
    --save_dir=${CHECKPOINT_PATH}/images/seed${RAND_SEED} \
    --img_meta_dir=${IMG_META_DIR} \
    --img_root_dir=${IMG_ROOT_DIR} \
    --tab_root_dir=${TAB_ROOT_DIR} \
    --seed=${RAND_SEED} \
    --batch_size=${BATCHSIZE}
```
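
If you prefer to launch the evaluation from Python (for example, directly after `hf_hub_download`), the command above can be assembled programmatically. The following is a minimal sketch, not part of the repository: `build_eval_command` is a hypothetical helper, the default `seed` and `batch_size` values are arbitrary placeholders, and the example paths are assumptions that you should replace with your own. The flags themselves are copied from the shell command above.

```python
import shlex
from pathlib import Path


def build_eval_command(weights_path, img_meta_dir, img_root_dir, tab_root_dir,
                       seed=0, batch_size=8):
    """Assemble the argument list for scripts/eval.py.

    Hypothetical helper: the flags mirror the shell command above, but the
    defaults for `seed` and `batch_size` are placeholders, not recommended values.
    """
    weights_path = Path(weights_path)
    # Assumption: images are saved next to the checkpoint, mirroring
    # ${CHECKPOINT_PATH}/images/seed${RAND_SEED} in the shell command.
    save_dir = weights_path.parent / "images" / f"seed{seed}"
    return [
        "python", "scripts/eval.py",
        f"--sdm_path={weights_path}",
        f"--save_dir={save_dir}",
        f"--img_meta_dir={img_meta_dir}",
        f"--img_root_dir={img_root_dir}",
        f"--tab_root_dir={tab_root_dir}",
        f"--seed={seed}",
        f"--batch_size={batch_size}",
    ]


if __name__ == "__main__":
    # Placeholder paths; replace them with your own directories.
    cmd = build_eval_command(
        "checkpoints/pytorch_model.bin",
        img_meta_dir="data/mimic-cxr-jpg-meta",
        img_root_dir="data/images",
        tab_root_dir="data/tabular",
        seed=42,
        batch_size=4,
    )
    # Print the command so it can be inspected before running it.
    print(shlex.join(cmd))
```

To execute the command, pass the returned list to `subprocess.run(cmd, check=True)` from the root of the cloned repository. The path returned by `hf_hub_download` can be used directly as `weights_path`.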