tan200224 committed on
Commit fa6ba40 · verified · 1 Parent(s): 0b54db7

Add README.md

Files changed (1)
  1. README.md +96 -0
README.md ADDED
@@ -0,0 +1,96 @@
+ # Conditional Diffusion Model for Medical Image Generation
+
+ This repository contains a conditional diffusion model that generates 3D medical CT scan images from segmentation masks. The score network is a U-Net trained with score-based diffusion.
+
+ ## Model Architecture
+
+ - **Base Model**: U-Net with 5-level encoder-decoder
+ - **Input**: 4-channel 256x256 CT scan images
+ - **Conditioning**: Segmentation masks (4-channel 256x256)
+ - **Output**: 4-channel 256x256 generated images
+ - **Sampling**: Euler-Maruyama sampler with 250 steps
+ - **Training**: Score matching loss with conditional generation
+
+ ## Model Details
+
+ - **Training Data**: 3,346 medical CT scan examples
+ - **Lambda Parameter**: 25.0 (noise-scale parameter passed to `marginal_prob_std` and `diffusion_coeff`)
+ - **Embedding Dimension**: 256
+ - **Channels**: [32, 64, 128, 256, 512]
+ - **Activation**: SiLU (Swish)
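
The `Lambda` value only has meaning through the noise schedule it parameterizes. The repository's `marginal_prob_std` and `diffusion_coeff` are not shown here, so the following is a minimal sketch that assumes they follow the common variance-exploding SDE formulation with diffusion coefficient `Lambda**t`. The `_sketch` function names are hypothetical, and scalar `math` stands in for tensors.

```python
import math

LAMBDA = 25.0  # value listed under Model Details


def diffusion_coeff_sketch(t, lam=LAMBDA):
    """Assumed diffusion coefficient g(t) = lam**t of the forward SDE."""
    return lam ** t


def marginal_prob_std_sketch(t, lam=LAMBDA):
    """Assumed std of p(x_t | x_0): sqrt((lam**(2t) - 1) / (2 ln lam))."""
    return math.sqrt((lam ** (2.0 * t) - 1.0) / (2.0 * math.log(lam)))


if __name__ == "__main__":
    # Print the schedule at an early, middle, and final diffusion time
    for t in (0.01, 0.5, 1.0):
        print(f"t={t:4.2f}  g(t)={diffusion_coeff_sketch(t):8.3f}  "
              f"std(t)={marginal_prob_std_sketch(t):7.3f}")
```

Under this assumption the noise standard deviation grows monotonically with `t` and reaches roughly 9.8 at `t = 1`, which would be the scale of the Gaussian prior the sampler starts from.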
+
+ ## Usage
+
+ ### Using the Hugging Face API
+
+ ```python
+ import torch
+ from huggingface_hub import hf_hub_download
+
+ # Download the checkpoint from the Hub
+ # (assumes ckpt_3D_v2.pth has been uploaded to this repository)
+ ckpt_path = hf_hub_download(
+     repo_id="your-username/your-model-name",
+     filename="ckpt_3D_v2.pth",
+ )
+
+ # Load the weights; build the model and sample as shown under "Local Usage"
+ state_dict = torch.load(ckpt_path, map_location="cpu")
+ ```
+
+ ### Local Usage
+
+ ```python
+ import torch
+ from model import UNet, marginal_prob_std, diffusion_coeff, Euler_Maruyama_sampler
+
+ # Use the GPU when available so the model and all inputs share one device
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ # Load model
+ Lambda = 25.0
+ marginal_prob_std_fn = lambda t: marginal_prob_std(t, Lambda=Lambda, device=device)
+ diffusion_coeff_fn = lambda t: diffusion_coeff(t, Lambda=Lambda, device=device)
+ score_model = UNet(marginal_prob_std=marginal_prob_std_fn).to(device)
+ score_model.load_state_dict(torch.load("ckpt_3D_v2.pth", map_location=device))
+ score_model.eval()
+
+ # Generate sample
+ # Placeholder conditioning input; replace with a real segmentation mask
+ conditioning_mask = torch.randn(1, 4, 256, 256, device=device)
+ samples = Euler_Maruyama_sampler(
+     score_model,
+     marginal_prob_std_fn,
+     diffusion_coeff_fn,
+     batch_size=1,
+     x_shape=(4, 256, 256),
+     num_steps=250,
+     device=device,
+     y=conditioning_mask
+ )
+ ```
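
`Euler_Maruyama_sampler` lives in the repository's `model` module and is not shown here. As a rough illustration of what such a sampler does, the sketch below assumes the usual reverse-time SDE update `x <- x + g(t)^2 * score(x, t) * dt + g(t) * sqrt(dt) * z` on a uniform time grid from `t = 1` down to a small `eps`. It works on a single scalar, omits the conditioning input `y`, and replaces the trained U-Net with a toy analytic score; every name in it is hypothetical.

```python
import math
import random

LAMBDA = 25.0


def diffusion_coeff_sketch(t, lam=LAMBDA):
    """Assumed diffusion coefficient g(t) = lam**t."""
    return lam ** t


def marginal_prob_std_sketch(t, lam=LAMBDA):
    """Assumed std of the perturbation kernel at time t."""
    return math.sqrt((lam ** (2.0 * t) - 1.0) / (2.0 * math.log(lam)))


def toy_score(x, t):
    """Exact score when all data sit at 0, so p_t = N(0, std(t)^2)."""
    return -x / marginal_prob_std_sketch(t) ** 2


def euler_maruyama_sketch(score_fn, num_steps=250, eps=1e-3, rng=None):
    """Integrate the reverse-time SDE from t=1 down to t=eps for one scalar."""
    rng = rng or random.Random()
    step = (1.0 - eps) / (num_steps - 1)
    # Start from the prior N(0, std(1)^2)
    x = rng.gauss(0.0, 1.0) * marginal_prob_std_sketch(1.0)
    mean_x = x
    for i in range(num_steps):
        t = 1.0 - i * step
        g = diffusion_coeff_sketch(t)
        # Drift step along the score, then fresh Gaussian noise
        mean_x = x + (g ** 2) * score_fn(x, t) * step
        x = mean_x + g * math.sqrt(step) * rng.gauss(0.0, 1.0)
    # Return the last mean so the final noise injection is not included
    return mean_x


if __name__ == "__main__":
    rng = random.Random(0)
    print([euler_maruyama_sketch(toy_score, rng=rng) for _ in range(5)])
```

With the toy score, samples that start at the wide prior contract toward 0, the only data point in this example. In the real model the score would come from the U-Net evaluated on the current image, the time, and the segmentation mask.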
+
+ ## Training
+
+ The model was trained for 5,000 epochs with the following settings:
+ - Learning rate: 2e-4 (with decay)
+ - Batch size: 1
+ - Optimizer: Adam
+ - Loss: Score matching loss
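
The README names the loss but gives no formula. A minimal sketch, assuming the commonly used denoising score matching objective `E[(std(t) * score(x_t, t) + z)^2]` with `x_t = x_0 + std(t) * z`, is shown below for scalar data. Conditioning on the mask is omitted, the noise schedule is the assumed `Lambda**t` form, and all function names are hypothetical.

```python
import math
import random

LAMBDA = 25.0


def marginal_prob_std_sketch(t, lam=LAMBDA):
    """Assumed std of the perturbation kernel at time t."""
    return math.sqrt((lam ** (2.0 * t) - 1.0) / (2.0 * math.log(lam)))


def dsm_loss_sketch(score_fn, data, rng, eps=1e-5):
    """Monte Carlo estimate of E[(std(t) * score(x_t, t) + z)^2]."""
    total = 0.0
    for x0 in data:
        t = rng.uniform(eps, 1.0)   # random diffusion time
        z = rng.gauss(0.0, 1.0)     # noise used to perturb x0
        std = marginal_prob_std_sketch(t)
        x_t = x0 + std * z          # sample from p(x_t | x0)
        total += (std * score_fn(x_t, t) + z) ** 2
    return total / len(data)


def exact_score(x, t):
    """True score when every data point equals 0."""
    return -x / marginal_prob_std_sketch(t) ** 2


if __name__ == "__main__":
    rng = random.Random(0)
    data = [0.0] * 2000
    print("exact score:", dsm_loss_sketch(exact_score, data, rng))
    print("zero score: ", dsm_loss_sketch(lambda x, t: 0.0, data, rng))
```

The exact score drives the loss to essentially zero, while a score that always returns 0 leaves a loss near 1, the variance of the injected noise. Training the U-Net would minimize the same kind of quantity over images rather than scalars.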
+
+ ## Dataset
+
+ The model was trained on medical CT scans with corresponding segmentation masks. The dataset contains 3,346 training examples, with an 80/20 train/validation split.
+
+ ## Citation
+
+ If you use this model in your research, please cite:
+
+ ```bibtex
+ @misc{conditional_diffusion_medical,
+ title={Conditional Diffusion Model for Medical Image Generation},
+ author={Your Name},
+ year={2024},
+ url={https://huggingface.co/your-username/your-model-name}
+ }
+ ```
+
+ ## License
+
+ [Add your license here]
+
+ ## Contact
+
+ For questions or issues, please open an issue on this repository.