tan200224 commited on
Commit
d78c654
·
verified ·
1 Parent(s): 4a5d5c6

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +99 -96
README.md CHANGED
@@ -1,96 +1,99 @@
1
- # Conditional Diffusion Model for Medical Image Generation
2
-
3
- This repository contains a conditional diffusion model trained to generate 3D medical CT scan images based on segmentation masks. The model uses a U-Net architecture with score-based diffusion for high-quality medical image synthesis.
4
-
5
- ## Model Architecture
6
-
7
- - **Base Model**: U-Net with 5-level encoder-decoder
8
- - **Input**: 4-channel 256x256 CT scan images
9
- - **Conditioning**: Segmentation masks (4-channel 256x256)
10
- - **Output**: 4-channel 256x256 generated images
11
- - **Sampling**: Euler-Maruyama sampler with 250 steps
12
- - **Training**: Score matching loss with conditional generation
13
-
14
- ## Model Details
15
-
16
- - **Training Data**: 3,346 medical CT scan examples
17
- - **Lambda Parameter**: 25.0 (diffusion coefficient)
18
- - **Embedding Dimension**: 256
19
- - **Channels**: [32, 64, 128, 256, 512]
20
- - **Activation**: SiLU (Swish)
21
-
22
- ## Usage
23
-
24
- ### Using the Hugging Face API
25
-
26
- ```python
27
- from transformers import AutoModelForImageGeneration
28
- import torch
29
-
30
- # Load the model
31
- model = AutoModelForImageGeneration.from_pretrained("your-username/your-model-name")
32
-
33
- # Generate images
34
- conditioning_mask = torch.randn(1, 4, 256, 256) # Your segmentation mask
35
- generated_image = model.generate(conditioning_mask)
36
- ```
37
-
38
- ### Local Usage
39
-
40
- ```python
41
- import torch
42
- from model import UNet, marginal_prob_std, diffusion_coeff, Euler_Maruyama_sampler
43
-
44
- # Load model
45
- Lambda = 25.0
46
- marginal_prob_std_fn = lambda t: marginal_prob_std(t, Lambda=Lambda, device='cuda')
47
- score_model = UNet(marginal_prob_std=marginal_prob_std_fn)
48
- score_model.load_state_dict(torch.load("ckpt_3D_v2.pth"))
49
- score_model.eval()
50
-
51
- # Generate sample
52
- conditioning_mask = torch.randn(1, 4, 256, 256)
53
- samples = Euler_Maruyama_sampler(
54
- score_model,
55
- marginal_prob_std_fn,
56
- lambda t: diffusion_coeff(t, Lambda=Lambda, device='cuda'),
57
- batch_size=1,
58
- x_shape=(4, 256, 256),
59
- num_steps=250,
60
- device='cuda',
61
- y=conditioning_mask
62
- )
63
- ```
64
-
65
- ## Training
66
-
67
- The model was trained for 5000 epochs with:
68
- - Learning rate: 2e-4 (with decay)
69
- - Batch size: 1
70
- - Optimizer: Adam
71
- - Loss: Score matching loss
72
-
73
- ## Dataset
74
-
75
- The model was trained on medical CT scan data with corresponding segmentation masks. The dataset contains 3,346 training examples with 80-20 train/validation split.
76
-
77
- ## Citation
78
-
79
- If you use this model in your research, please cite:
80
-
81
- ```bibtex
82
- @misc{conditional_diffusion_medical,
83
- title={Conditional Diffusion Model for Medical Image Generation},
84
- author={Your Name},
85
- year={2024},
86
- url={https://huggingface.co/your-username/your-model-name}
87
- }
88
- ```
89
-
90
- ## License
91
-
92
- [Add your license here]
93
-
94
- ## Contact
95
-
96
- For questions or issues, please open an issue on this repository.
 
 
 
 
1
+ ---
2
+ license: mit
3
+ ---
4
+ # Conditional Diffusion Model for Medical Image Generation
5
+
6
+ This repository contains a conditional diffusion model trained to generate 3D medical CT scan images based on segmentation masks. The model uses a U-Net architecture with score-based diffusion for high-quality medical image synthesis.
7
+
8
+ ## Model Architecture
9
+
10
+ - **Base Model**: U-Net with 5-level encoder-decoder
11
+ - **Input**: 4-channel 256x256 CT scan images
12
+ - **Conditioning**: Segmentation masks (4-channel 256x256)
13
+ - **Output**: 4-channel 256x256 generated images
14
+ - **Sampling**: Euler-Maruyama sampler with 250 steps
15
+ - **Training**: Score matching loss with conditional generation
16
+
17
+ ## Model Details
18
+
19
+ - **Training Data**: 3,346 medical CT scan examples
20
+ - **Lambda Parameter**: 25.0 (diffusion coefficient)
21
+ - **Embedding Dimension**: 256
22
+ - **Channels**: [32, 64, 128, 256, 512]
23
+ - **Activation**: SiLU (Swish)
24
+
25
+ ## Usage
26
+
27
+ ### Using the Hugging Face API
28
+
29
+ ```python
30
+ from transformers import AutoModelForImageGeneration
31
+ import torch
32
+
33
+ # Load the model
34
+ model = AutoModelForImageGeneration.from_pretrained("your-username/your-model-name")
35
+
36
+ # Generate images
37
+ conditioning_mask = torch.randn(1, 4, 256, 256) # Your segmentation mask
38
+ generated_image = model.generate(conditioning_mask)
39
+ ```
40
+
41
+ ### Local Usage
42
+
43
+ ```python
44
+ import torch
45
+ from model import UNet, marginal_prob_std, diffusion_coeff, Euler_Maruyama_sampler
46
+
47
+ # Load model
48
+ Lambda = 25.0
49
+ marginal_prob_std_fn = lambda t: marginal_prob_std(t, Lambda=Lambda, device='cuda')
50
+ score_model = UNet(marginal_prob_std=marginal_prob_std_fn)
51
+ score_model.load_state_dict(torch.load("ckpt_3D_v2.pth"))
52
+ score_model.eval()
53
+
54
+ # Generate sample
55
+ conditioning_mask = torch.randn(1, 4, 256, 256)
56
+ samples = Euler_Maruyama_sampler(
57
+ score_model,
58
+ marginal_prob_std_fn,
59
+ lambda t: diffusion_coeff(t, Lambda=Lambda, device='cuda'),
60
+ batch_size=1,
61
+ x_shape=(4, 256, 256),
62
+ num_steps=250,
63
+ device='cuda',
64
+ y=conditioning_mask
65
+ )
66
+ ```
67
+
68
+ ## Training
69
+
70
+ The model was trained for 5000 epochs with:
71
+ - Learning rate: 2e-4 (with decay)
72
+ - Batch size: 1
73
+ - Optimizer: Adam
74
+ - Loss: Score matching loss
75
+
76
+ ## Dataset
77
+
78
+ The model was trained on medical CT scan data with corresponding segmentation masks. The dataset contains 3,346 training examples with 80-20 train/validation split.
79
+
80
+ ## Citation
81
+
82
+ If you use this model in your research, please cite:
83
+
84
+ ```bibtex
85
+ @misc{conditional_diffusion_medical,
86
+ title={Conditional Diffusion Model for Medical Image Generation},
87
+ author={Archie Tan, Scott, Spurlock},
88
+ year={2025},
89
+ url={https://huggingface.co/tan200224}
90
+ }
91
+ ```
92
+
93
+ ## License
94
+
95
+ [Add your license here]
96
+
97
+ ## Contact
98
+
99
+ For questions or issues, please open an issue on this repository.