BiliSakura committed on
Commit
01176da
·
verified ·
1 Parent(s): 1af9de2

Add files using upload-large-folder tool

Browse files
Files changed (31) hide show
  1. README.md +82 -0
  2. cd_head/cdd-50-100-400-650/config.json +25 -0
  3. cd_head/cdd-50-100-400-650/diffusion_pytorch_model.safetensors +3 -0
  4. cd_head/cdd-50-100-400/config.json +24 -0
  5. cd_head/cdd-50-100-400/diffusion_pytorch_model.safetensors +3 -0
  6. cd_head/cdd-50-100/config.json +23 -0
  7. cd_head/cdd-50-100/diffusion_pytorch_model.safetensors +3 -0
  8. cd_head/dsifn-50-100-400-650/config.json +25 -0
  9. cd_head/dsifn-50-100-400-650/diffusion_pytorch_model.safetensors +3 -0
  10. cd_head/dsifn-50-100-400/config.json +24 -0
  11. cd_head/dsifn-50-100-400/diffusion_pytorch_model.safetensors +3 -0
  12. cd_head/dsifn-50-100/config.json +23 -0
  13. cd_head/dsifn-50-100/diffusion_pytorch_model.safetensors +3 -0
  14. cd_head/levir-50-100-400-650/config.json +25 -0
  15. cd_head/levir-50-100-400-650/diffusion_pytorch_model.safetensors +3 -0
  16. cd_head/levir-50-100-400/config.json +24 -0
  17. cd_head/levir-50-100-400/diffusion_pytorch_model.safetensors +3 -0
  18. cd_head/levir-50-100/config.json +23 -0
  19. cd_head/levir-50-100/diffusion_pytorch_model.safetensors +3 -0
  20. cd_head/whu-50-100-400-650/config.json +25 -0
  21. cd_head/whu-50-100-400-650/diffusion_pytorch_model.safetensors +3 -0
  22. cd_head/whu-50-100-400/config.json +24 -0
  23. cd_head/whu-50-100-400/diffusion_pytorch_model.safetensors +3 -0
  24. cd_head/whu-50-100/config.json +23 -0
  25. cd_head/whu-50-100/diffusion_pytorch_model.safetensors +3 -0
  26. model_index.json +12 -0
  27. pipeline.py +518 -0
  28. scheduler/scheduler_config.json +19 -0
  29. unet/config.json +23 -0
  30. unet/diffusion_pytorch_model.safetensors +3 -0
  31. unet/unet.py +495 -0
README.md ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ tags:
4
+ - diffusers
5
+ - ddpm-cd
6
+ - change-detection
7
+ - remote-sensing
8
+ ---
9
+
10
+ # BiliSakura/ddpm-cd
11
+
12
+ **Consolidated DDPM-CD change detection** — Single repo with shared UNet backbone and multiple cd_head variants (trained on different datasets and timestep configs).
13
+
14
+ ## Model Structure
15
+
16
+ - **Backbone**: Shared SR3-style UNet (same across all variants)
17
+ - **cd_head**: Dataset-specific change detection heads in `cd_head/{variant}/`
18
+
19
+ ### Available cd_head Variants
20
+
21
+ | Variant | Dataset | Timesteps | Path |
22
+ |---------|---------|-----------|------|
23
+ | cdd-50-100 | CDD | [50, 100] | `cd_head/cdd-50-100/` |
24
+ | cdd-50-100-400 | CDD | [50, 100, 400] | `cd_head/cdd-50-100-400/` |
25
+ | cdd-50-100-400-650 | CDD | [50, 100, 400, 650] | `cd_head/cdd-50-100-400-650/` |
26
+ | dsifn-50-100 | DSIFN | [50, 100] | `cd_head/dsifn-50-100/` |
27
+ | dsifn-50-100-400 | DSIFN | [50, 100, 400] | `cd_head/dsifn-50-100-400/` |
28
+ | dsifn-50-100-400-650 | DSIFN | [50, 100, 400, 650] | `cd_head/dsifn-50-100-400-650/` |
29
+ | levir-50-100 | LEVIR | [50, 100] | `cd_head/levir-50-100/` |
30
+ | levir-50-100-400 | LEVIR | [50, 100, 400] | `cd_head/levir-50-100-400/` |
31
+ | levir-50-100-400-650 | LEVIR | [50, 100, 400, 650] | `cd_head/levir-50-100-400-650/` |
32
+ | whu-50-100 | WHU | [50, 100] | `cd_head/whu-50-100/` |
33
+ | whu-50-100-400 | WHU | [50, 100, 400] | `cd_head/whu-50-100-400/` |
34
+ | whu-50-100-400-650 | WHU | [50, 100, 400, 650] | `cd_head/whu-50-100-400-650/` |
35
+
36
+ ## Usage
37
+
38
+ Load the pipeline with `custom_pipeline="pipeline"` (this resolves to the `pipeline.py` file in the repo) and choose a change-detection head with `cd_head_subfolder`:
39
+
40
+ ```python
41
+ from diffusers import DiffusionPipeline
42
+
43
+ pipe = DiffusionPipeline.from_pretrained(
44
+ "BiliSakura/ddpm-cd",
45
+ custom_pipeline="pipeline",
46
+ trust_remote_code=True,
47
+ cd_head_subfolder="levir-50-100",
48
+ )
49
+
50
+ # Images in [-1, 1], shape (B, 3, H, W)
51
+ change_map = pipe(image_A, image_B, timesteps=[50, 100])
52
+ pred = change_map.argmax(1) # (B, H, W), 0=no-change, 1=change
53
+ ```
54
+
55
+ **Important**: Pass the same `timesteps` that the selected variant was trained with (see the table above); the number of timesteps must match the head's configuration.
56
+
57
+ ### Switching cd_head at Runtime
58
+
59
+ ```python
60
+ pipe = DiffusionPipeline.from_pretrained(
61
+ "BiliSakura/ddpm-cd",
62
+ custom_pipeline="pipeline",
63
+ trust_remote_code=True,
64
+ cd_head_subfolder="levir-50-100",
65
+ )
66
+ # Load different cd_head
67
+ pipe.load_cd_head(subfolder="whu-50-100-400")
68
+ change_map = pipe(image_A, image_B, timesteps=[50, 100, 400])
69
+ ```
70
+
71
+ ## Citation
72
+
73
+ ```bibtex
74
+ @misc{bandara2024ddpmcdv3,
75
+ title={DDPM-CD: Denoising Diffusion Probabilistic Models as Feature Extractors for Change Detection},
76
+ author={Wele Gedara Chaminda Bandara and Nithin Gopalakrishnan Nair and Vishal M. Patel},
77
+ year={2024},
78
+ eprint={2206.11892},
79
+ archivePrefix={arXiv},
80
+ primaryClass={cs.CV},
81
+ }
82
+ ```
cd_head/cdd-50-100-400-650/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "feat_scales": [
3
+ 2,
4
+ 5,
5
+ 8,
6
+ 11,
7
+ 14
8
+ ],
9
+ "inner_channel": 128,
10
+ "channel_multiplier": [
11
+ 1,
12
+ 2,
13
+ 4,
14
+ 8,
15
+ 8
16
+ ],
17
+ "out_channels": 2,
18
+ "img_size": 256,
19
+ "time_steps": [
20
+ 50,
21
+ 100,
22
+ 400,
23
+ 650
24
+ ]
25
+ }
cd_head/cdd-50-100-400-650/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:088be206470c3e764d1e766754319d243d7a4e1c078fa21edca0a3658de3834e
3
+ size 195390880
cd_head/cdd-50-100-400/config.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "feat_scales": [
3
+ 2,
4
+ 5,
5
+ 8,
6
+ 11,
7
+ 14
8
+ ],
9
+ "inner_channel": 128,
10
+ "channel_multiplier": [
11
+ 1,
12
+ 2,
13
+ 4,
14
+ 8,
15
+ 8
16
+ ],
17
+ "out_channels": 2,
18
+ "img_size": 256,
19
+ "time_steps": [
20
+ 50,
21
+ 100,
22
+ 400
23
+ ]
24
+ }
cd_head/cdd-50-100-400/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bd3786cca5e126438d6d6676504c260d3d058e27afc8d43dda0627199eb4dc9e
3
+ size 185626008
cd_head/cdd-50-100/config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "feat_scales": [
3
+ 2,
4
+ 5,
5
+ 8,
6
+ 11,
7
+ 14
8
+ ],
9
+ "inner_channel": 128,
10
+ "channel_multiplier": [
11
+ 1,
12
+ 2,
13
+ 4,
14
+ 8,
15
+ 8
16
+ ],
17
+ "out_channels": 2,
18
+ "img_size": 256,
19
+ "time_steps": [
20
+ 50,
21
+ 100
22
+ ]
23
+ }
cd_head/cdd-50-100/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cf58a46ac8449df67802ec45c5346ae183766951c54328f43b1d018f794f7ed2
3
+ size 175861136
cd_head/dsifn-50-100-400-650/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "feat_scales": [
3
+ 2,
4
+ 5,
5
+ 8,
6
+ 11,
7
+ 14
8
+ ],
9
+ "inner_channel": 128,
10
+ "channel_multiplier": [
11
+ 1,
12
+ 2,
13
+ 4,
14
+ 8,
15
+ 8
16
+ ],
17
+ "out_channels": 2,
18
+ "img_size": 256,
19
+ "time_steps": [
20
+ 50,
21
+ 100,
22
+ 400,
23
+ 650
24
+ ]
25
+ }
cd_head/dsifn-50-100-400-650/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c35c9945b1dab746cc7e5eb1cb940dfb1829783ca4dbb033bf44bf27e7a6cc87
3
+ size 195390880
cd_head/dsifn-50-100-400/config.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "feat_scales": [
3
+ 2,
4
+ 5,
5
+ 8,
6
+ 11,
7
+ 14
8
+ ],
9
+ "inner_channel": 128,
10
+ "channel_multiplier": [
11
+ 1,
12
+ 2,
13
+ 4,
14
+ 8,
15
+ 8
16
+ ],
17
+ "out_channels": 2,
18
+ "img_size": 256,
19
+ "time_steps": [
20
+ 50,
21
+ 100,
22
+ 400
23
+ ]
24
+ }
cd_head/dsifn-50-100-400/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df8e5ea2901fb902568aef6beaeff4e77516d83aef6c24c2a7e5455e393e13a7
3
+ size 185626008
cd_head/dsifn-50-100/config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "feat_scales": [
3
+ 2,
4
+ 5,
5
+ 8,
6
+ 11,
7
+ 14
8
+ ],
9
+ "inner_channel": 128,
10
+ "channel_multiplier": [
11
+ 1,
12
+ 2,
13
+ 4,
14
+ 8,
15
+ 8
16
+ ],
17
+ "out_channels": 2,
18
+ "img_size": 256,
19
+ "time_steps": [
20
+ 50,
21
+ 100
22
+ ]
23
+ }
cd_head/dsifn-50-100/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8b2214ccc1aa96b62bff1e810a64d81225495dd0b6500e26c11cc41c388d43b3
3
+ size 175861136
cd_head/levir-50-100-400-650/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "feat_scales": [
3
+ 2,
4
+ 5,
5
+ 8,
6
+ 11,
7
+ 14
8
+ ],
9
+ "inner_channel": 128,
10
+ "channel_multiplier": [
11
+ 1,
12
+ 2,
13
+ 4,
14
+ 8,
15
+ 8
16
+ ],
17
+ "out_channels": 2,
18
+ "img_size": 256,
19
+ "time_steps": [
20
+ 50,
21
+ 100,
22
+ 400,
23
+ 650
24
+ ]
25
+ }
cd_head/levir-50-100-400-650/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7712ce94b7c16adb675590a0b0909555de650b9b7dd4481e2f25557f643594fa
3
+ size 195390880
cd_head/levir-50-100-400/config.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "feat_scales": [
3
+ 2,
4
+ 5,
5
+ 8,
6
+ 11,
7
+ 14
8
+ ],
9
+ "inner_channel": 128,
10
+ "channel_multiplier": [
11
+ 1,
12
+ 2,
13
+ 4,
14
+ 8,
15
+ 8
16
+ ],
17
+ "out_channels": 2,
18
+ "img_size": 256,
19
+ "time_steps": [
20
+ 50,
21
+ 100,
22
+ 400
23
+ ]
24
+ }
cd_head/levir-50-100-400/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:395683768fe2a31d2d10b3231c118a20e4b7107247fe7808d125c29635c8a419
3
+ size 185626008
cd_head/levir-50-100/config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "feat_scales": [
3
+ 2,
4
+ 5,
5
+ 8,
6
+ 11,
7
+ 14
8
+ ],
9
+ "inner_channel": 128,
10
+ "channel_multiplier": [
11
+ 1,
12
+ 2,
13
+ 4,
14
+ 8,
15
+ 8
16
+ ],
17
+ "out_channels": 2,
18
+ "img_size": 256,
19
+ "time_steps": [
20
+ 50,
21
+ 100
22
+ ]
23
+ }
cd_head/levir-50-100/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5624b511362aab2ae735fc2e670c4e4dd3d2405e51c3d3e608d3027a4288ad85
3
+ size 175861136
cd_head/whu-50-100-400-650/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "feat_scales": [
3
+ 2,
4
+ 5,
5
+ 8,
6
+ 11,
7
+ 14
8
+ ],
9
+ "inner_channel": 128,
10
+ "channel_multiplier": [
11
+ 1,
12
+ 2,
13
+ 4,
14
+ 8,
15
+ 8
16
+ ],
17
+ "out_channels": 2,
18
+ "img_size": 256,
19
+ "time_steps": [
20
+ 50,
21
+ 100,
22
+ 400,
23
+ 650
24
+ ]
25
+ }
cd_head/whu-50-100-400-650/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:36d6d97b2839c6cb19b66a5aa709d43bb6fd8cf508341d7f0cc4d041ac7e8d76
3
+ size 195390880
cd_head/whu-50-100-400/config.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "feat_scales": [
3
+ 2,
4
+ 5,
5
+ 8,
6
+ 11,
7
+ 14
8
+ ],
9
+ "inner_channel": 128,
10
+ "channel_multiplier": [
11
+ 1,
12
+ 2,
13
+ 4,
14
+ 8,
15
+ 8
16
+ ],
17
+ "out_channels": 2,
18
+ "img_size": 256,
19
+ "time_steps": [
20
+ 50,
21
+ 100,
22
+ 400
23
+ ]
24
+ }
cd_head/whu-50-100-400/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:29d7151abaed8eb0280fb92beb9cd5ef73329fb269f7b77de0b612dcdbc8cd21
3
+ size 185626008
cd_head/whu-50-100/config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "feat_scales": [
3
+ 2,
4
+ 5,
5
+ 8,
6
+ 11,
7
+ 14
8
+ ],
9
+ "inner_channel": 128,
10
+ "channel_multiplier": [
11
+ 1,
12
+ 2,
13
+ 4,
14
+ 8,
15
+ 8
16
+ ],
17
+ "out_channels": 2,
18
+ "img_size": 256,
19
+ "time_steps": [
20
+ 50,
21
+ 100
22
+ ]
23
+ }
cd_head/whu-50-100/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:de3fe09d7c0df23f0d790e8d2d1fcbf080c8f835e044e70df7d40f38a7112891
3
+ size 175861136
model_index.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": ["pipeline", "DDPMCDPipeline"],
3
+ "_diffusers_version": "0.36.0",
4
+ "scheduler": [
5
+ "diffusers",
6
+ "DDPMScheduler"
7
+ ],
8
+ "unet": [
9
+ "unet",
10
+ "UNet"
11
+ ]
12
+ }
pipeline.py ADDED
@@ -0,0 +1,518 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DDPMCDPipeline for change detection.
3
+ pipeline.py is in the repo — use custom_pipeline="pipeline" (relative path).
4
+
5
+ Usage::
6
+
7
+ from diffusers import DiffusionPipeline
8
+
9
+ pipe = DiffusionPipeline.from_pretrained(
10
+ "BiliSakura/ddpm-cd",
11
+ custom_pipeline="pipeline",
12
+ trust_remote_code=True,
13
+ cd_head_subfolder="levir-50-100",
14
+ )
15
+ change_map = pipe(image_A, image_B, timesteps=[50, 100])
16
+ """
17
+
18
+ import json
19
+ import math
20
+ import os
21
+ from inspect import isfunction
22
+
23
+ import numpy as np
24
+ import torch
25
+ import torch.nn as nn
26
+ import torch.nn.functional as F
27
+ from diffusers import DDPMScheduler
28
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
29
+ from diffusers.models.modeling_utils import ModelMixin # ModelMixin subclasses nn.Module
30
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
31
+ from tqdm.auto import tqdm
32
+
33
+
34
+ # ===========================================================================
35
+ # UNet (SR3-style) - all components inlined
36
+ # ===========================================================================
37
+
38
def _exists(x):
    """Return True when ``x`` is anything other than ``None``."""
    return not (x is None)
40
+
41
+
42
def _default(val, d):
    """Return ``val`` unless it is ``None``, otherwise fall back to ``d``.

    When the fallback ``d`` is a plain Python function it is called and its
    result is returned; any other object is returned as-is.
    """
    if val is None:
        return d() if isfunction(d) else d
    return val
46
+
47
+
48
class PositionalEncoding(nn.Module):
    """Sinusoidal embedding of a (continuous) noise level.

    The level is multiplied by ``dim // 2`` frequencies that decay
    geometrically from 1 toward 1e-4; the sines and cosines of the products
    are concatenated along the last axis.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, noise_level):
        half = self.dim // 2
        # Fractions 0, 1/half, ..., (half-1)/half in the input's dtype/device.
        exponents = torch.arange(half, dtype=noise_level.dtype, device=noise_level.device) / half
        freqs = torch.exp(-math.log(1e4) * exponents.unsqueeze(0))
        phase = noise_level.unsqueeze(1) * freqs
        return torch.cat((torch.sin(phase), torch.cos(phase)), dim=-1)
58
+
59
+
60
class FeatureWiseAffine(nn.Module):
    """Inject a noise embedding into a feature map.

    A linear layer projects the embedding to per-channel values. With
    ``use_affine_level`` the projection is split into a scale and a shift
    (``(1 + gamma) * x + beta``); otherwise it is simply added to ``x``.
    """

    def __init__(self, in_channels, out_channels, use_affine_level=False):
        super().__init__()
        self.use_affine_level = use_affine_level
        # Twice as many outputs are needed when both gamma and beta are produced.
        projection_width = out_channels * (1 + self.use_affine_level)
        self.noise_func = nn.Sequential(nn.Linear(in_channels, projection_width))

    def forward(self, x, noise_embed):
        projected = self.noise_func(noise_embed).view(x.shape[0], -1, 1, 1)
        if not self.use_affine_level:
            return x + projected
        gamma, beta = projected.chunk(2, dim=1)
        return (1 + gamma) * x + beta
74
+
75
+
76
class Swish(nn.Module):
    """Swish activation: the input multiplied by its own sigmoid."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate
79
+
80
+
81
class Upsample(nn.Module):
    """Double the spatial size (nearest neighbour) and refine with a 3x3 conv."""

    def __init__(self, dim):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="nearest")
        self.conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1)

    def forward(self, x):
        enlarged = self.up(x)
        return self.conv(enlarged)
89
+
90
+
91
class Downsample(nn.Module):
    """Halve the spatial size with a stride-2 3x3 convolution (padding 1)."""

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv2d(dim, dim, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        reduced = self.conv(x)
        return reduced
98
+
99
+
100
class Block(nn.Module):
    """GroupNorm -> Swish -> (optional Dropout) -> 3x3 conv.

    The dropout slot is always present in the sequence (an ``nn.Identity``
    when ``dropout`` is 0) so layer indices are the same either way.
    """

    def __init__(self, dim, dim_out, groups=32, dropout=0):
        super().__init__()
        regulariser = nn.Identity() if dropout == 0 else nn.Dropout(dropout)
        layers = [
            nn.GroupNorm(groups, dim),
            Swish(),
            regulariser,
            nn.Conv2d(dim, dim_out, 3, padding=1),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        out = self.block(x)
        return out
112
+
113
+
114
class ResnetBlock(nn.Module):
    """Residual block conditioned on a noise-level embedding.

    The input goes through ``block1``, is modulated by the embedding via
    ``noise_func``, then through ``block2`` (the only place dropout is
    applied). A 1x1 conv aligns the skip path when the channel count changes.
    """

    def __init__(self, dim, dim_out, noise_level_emb_dim=None, dropout=0, use_affine_level=False, norm_groups=32):
        super().__init__()
        self.noise_func = FeatureWiseAffine(noise_level_emb_dim, dim_out, use_affine_level)
        self.block1 = Block(dim, dim_out, groups=norm_groups)
        self.block2 = Block(dim_out, dim_out, groups=norm_groups, dropout=dropout)
        if dim == dim_out:
            self.res_conv = nn.Identity()
        else:
            self.res_conv = nn.Conv2d(dim, dim_out, 1)

    def forward(self, x, time_emb):
        hidden = self.block1(x)
        hidden = self.noise_func(hidden, time_emb)
        hidden = self.block2(hidden)
        shortcut = self.res_conv(x)
        return hidden + shortcut
127
+
128
+
129
class SelfAttention(nn.Module):
    """Spatial self-attention over all pixel positions, with a residual add.

    Queries, keys and values come from one bias-free 1x1 conv applied to the
    group-normalised input. Scores are scaled by ``sqrt(channel)`` (the full
    channel count, not the per-head width) and soft-maxed over every key
    position.
    """

    def __init__(self, in_channel, n_head=1, norm_groups=32):
        super().__init__()
        self.n_head = n_head
        self.norm = nn.GroupNorm(norm_groups, in_channel)
        self.qkv = nn.Conv2d(in_channel, in_channel * 3, 1, bias=False)
        self.out = nn.Conv2d(in_channel, in_channel, 1)

    def forward(self, input):
        batch, channel, height, width = input.shape
        heads = self.n_head
        per_head = channel // heads

        packed = self.qkv(self.norm(input))
        packed = packed.view(batch, heads, per_head * 3, height, width)
        query, key, value = packed.chunk(3, dim=2)

        # Scores between every query position (h, w) and key position (y, x).
        scores = torch.einsum("bnchw, bncyx -> bnhwyx", query, key).contiguous() / math.sqrt(channel)
        # Flatten the key grid so the softmax normalises over all key positions.
        weights = torch.softmax(scores.view(batch, heads, height, width, -1), -1)
        weights = weights.view(batch, heads, height, width, height, width)

        attended = torch.einsum("bnhwyx, bncyx -> bnchw", weights, value).contiguous()
        projected = self.out(attended.view(batch, channel, height, width))
        return projected + input
148
+
149
+
150
class ResnetBlocWithAttn(nn.Module):
    """A ``ResnetBlock`` optionally followed by ``SelfAttention``.

    ``self.attn`` is ``None`` when ``with_attn`` is False, so no attention
    parameters are created in that case.
    """

    def __init__(self, dim, dim_out, *, noise_level_emb_dim=None, norm_groups=32, dropout=0, with_attn=False):
        super().__init__()
        self.with_attn = with_attn
        self.res_block = ResnetBlock(dim, dim_out, noise_level_emb_dim, norm_groups=norm_groups, dropout=dropout)
        if with_attn:
            self.attn = SelfAttention(dim_out, norm_groups=norm_groups)
        else:
            self.attn = None

    def forward(self, x, time_emb):
        out = self.res_block(x, time_emb)
        if not self.with_attn:
            return out
        return self.attn(out)
162
+
163
+
164
class UNet(ModelMixin, ConfigMixin):
    """SR3-style UNet conditioned on a continuous noise level.

    ``forward(x, time)`` returns the network output. With ``feat_need=True``
    it instead returns the intermediate activations consumed by the
    change-detection head.

    Args:
        in_channel: Channels of the input image tensor.
        out_channel: Channels of the output; falls back to ``in_channel``
            when ``None``.
        inner_channel: Base width; each stage uses
            ``inner_channel * channel_mults[i]`` channels.
        norm_groups: Number of groups for every GroupNorm.
        channel_mults: Per-stage width multipliers (one stage per entry).
        attn_res: Spatial resolutions at which residual blocks get attention.
        res_blocks: Residual blocks per encoder stage (decoder stages use
            ``res_blocks + 1``).
        dropout: Dropout probability inside the residual blocks.
        with_noise_level_emb: Build the noise-level MLP when True.
        image_size: Input resolution, used only to decide where attention
            is enabled.
    """

    @register_to_config
    def __init__(
        self,
        in_channel=6,
        out_channel=3,
        inner_channel=32,
        norm_groups=32,
        channel_mults=(1, 2, 4, 8, 8),
        attn_res=(8,),
        res_blocks=3,
        dropout=0,
        with_noise_level_emb=True,
        image_size=128,
    ):
        super().__init__()
        if with_noise_level_emb:
            noise_level_channel = inner_channel
            self.noise_level_mlp = nn.Sequential(
                PositionalEncoding(inner_channel),
                nn.Linear(inner_channel, inner_channel * 4),
                Swish(),
                nn.Linear(inner_channel * 4, inner_channel),
            )
        else:
            noise_level_channel = None
            self.noise_level_mlp = None

        def res_stage(dim_in, dim_out, with_attn):
            # All residual stages share the same conditioning/norm/dropout setup.
            return ResnetBlocWithAttn(
                dim_in,
                dim_out,
                noise_level_emb_dim=noise_level_channel,
                norm_groups=norm_groups,
                dropout=dropout,
                with_attn=with_attn,
            )

        num_mults = len(channel_mults)
        pre_channel = inner_channel
        # Channel counts of every encoder activation, consumed (popped) by the decoder.
        feat_channels = [inner_channel]
        now_res = image_size
        self.init_conv = nn.Conv2d(in_channel, inner_channel, 3, padding=1)

        downs = []
        for ind, mult in enumerate(channel_mults):
            use_attn = now_res in attn_res
            channel_mult = inner_channel * mult
            for _ in range(res_blocks):
                downs.append(res_stage(pre_channel, channel_mult, use_attn))
                feat_channels.append(channel_mult)
                pre_channel = channel_mult
            if ind < num_mults - 1:
                downs.append(Downsample(pre_channel))
                feat_channels.append(pre_channel)
                now_res = now_res // 2
        self.downs = nn.ModuleList(downs)

        self.mid = nn.ModuleList([
            res_stage(pre_channel, pre_channel, True),
            res_stage(pre_channel, pre_channel, False),
        ])

        ups = []
        for ind in reversed(range(num_mults)):
            use_attn = now_res in attn_res
            channel_mult = inner_channel * channel_mults[ind]
            for _ in range(res_blocks + 1):
                # Each decoder block sees the running tensor concatenated with one skip.
                ups.append(res_stage(pre_channel + feat_channels.pop(), channel_mult, use_attn))
                pre_channel = channel_mult
            if ind > 0:
                ups.append(Upsample(pre_channel))
                now_res = now_res * 2
        self.ups = nn.ModuleList(ups)
        self.final_conv = Block(pre_channel, _default(out_channel, lambda: in_channel), groups=norm_groups)

    def forward(self, x, time, feat_need=False):
        """Run the UNet.

        Args:
            x: Input image tensor.
            time: Noise level fed to the noise-level MLP (ignored by the MLP
                path when the model was built without it).
            feat_need: When True, return intermediate features instead of
                the network output.

        Returns:
            The output of ``final_conv`` when ``feat_need`` is False.
            Otherwise a tuple ``(encoder_feats, decoder_feats)``:
            ``encoder_feats`` lists the activation after ``init_conv`` and
            after every encoder layer; ``decoder_feats`` lists the output of
            every decoder residual block in reverse order (last block first).
        """
        t = self.noise_level_mlp(time) if _exists(self.noise_level_mlp) else None
        x = self.init_conv(x)

        skips = [x]
        for layer in self.downs:
            x = layer(x, t) if isinstance(layer, ResnetBlocWithAttn) else layer(x)
            skips.append(x)
        # Snapshot before the decoder starts popping skips.
        encoder_feats = skips.copy() if feat_need else None

        for layer in self.mid:
            x = layer(x, t) if isinstance(layer, ResnetBlocWithAttn) else layer(x)

        decoder_feats = []
        for layer in self.ups:
            if isinstance(layer, ResnetBlocWithAttn):
                x = layer(torch.cat((x, skips.pop()), dim=1), t)
                if feat_need:
                    decoder_feats.append(x)
            else:
                x = layer(x)

        if feat_need:
            # The final conv's result is never part of the returned features,
            # so it is skipped on the feature-extraction path.
            return encoder_feats, decoder_feats[::-1]
        return self.final_conv(x)
264
+
265
+
266
+ # ===========================================================================
267
+ # Change detection head
268
+ # ===========================================================================
269
+
270
class ChannelSELayer(nn.Module):
    """Channel squeeze-and-excitation.

    Each channel is averaged over its spatial positions, passed through a
    two-layer bottleneck (``num_channels // reduction_ratio`` hidden units)
    and a sigmoid, and the resulting per-channel gate rescales the input.
    """

    def __init__(self, num_channels, reduction_ratio=2):
        super().__init__()
        hidden = num_channels // reduction_ratio
        self.fc1 = nn.Linear(num_channels, hidden, bias=True)
        self.fc2 = nn.Linear(hidden, num_channels, bias=True)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        batch, channels = x.size()[:2]
        squeezed = x.view(batch, channels, -1).mean(dim=2)
        excited = self.fc2(self.relu(self.fc1(squeezed)))
        gate = self.sigmoid(excited).view(batch, channels, 1, 1)
        return x * gate
283
+
284
+
285
class SpatialSELayer(nn.Module):
    """Spatial squeeze-and-excitation.

    A 1x1 conv collapses the channels to one map; its sigmoid gates every
    spatial position of the input. When ``weights`` is passed to ``forward``
    it is used as the 1x1 kernel (without bias) instead of the learned conv.
    """

    def __init__(self, num_channels):
        super().__init__()
        self.conv = nn.Conv2d(num_channels, 1, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, weights=None):
        batch, channels, height, width = x.size()
        if weights is None:
            squeezed = self.conv(x)
        else:
            squeezed = F.conv2d(x, weights.view(1, channels, 1, 1))
        gate = self.sigmoid(squeezed).view(batch, 1, height, width)
        return x * gate
295
+
296
+
297
class ChannelSpatialSELayer(nn.Module):
    """Sum of channel-wise and spatial squeeze-and-excitation of the input."""

    def __init__(self, num_channels, reduction_ratio=2):
        super().__init__()
        self.cSE = ChannelSELayer(num_channels, reduction_ratio)
        self.sSE = SpatialSELayer(num_channels)

    def forward(self, x):
        channel_gated = self.cSE(x)
        spatial_gated = self.sSE(x)
        return channel_gated + spatial_gated
305
+
306
+
307
def _get_in_channels(feat_scales, inner_channel, channel_multiplier):
    """Sum the channel widths of the feature maps selected by ``feat_scales``.

    Feature indices are grouped in threes: indices below 3 use
    ``channel_multiplier[0]``, below 6 use ``[1]``, and so on up to index 14
    (``[4]``). Each selected index contributes ``inner_channel`` times its
    group's multiplier.

    Raises:
        ValueError: If an index is 15 or larger.
    """
    group_upper_bounds = (3, 6, 9, 12, 15)
    total = 0
    for scale in feat_scales:
        for group, upper in enumerate(group_upper_bounds):
            if scale < upper:
                total += inner_channel * channel_multiplier[group]
                break
        else:
            raise ValueError("feat_scales 0<=s<=14")
    return total
318
+
319
+
320
class AttentionBlock(nn.Module):
    """3x3 conv -> ReLU -> channel+spatial squeeze-and-excitation."""

    def __init__(self, dim, dim_out):
        super().__init__()
        stages = (
            nn.Conv2d(dim, dim_out, 3, padding=1),
            nn.ReLU(),
            ChannelSpatialSELayer(dim_out, 2),
        )
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        out = self.block(x)
        return out
331
+
332
+
333
class CDBlock(nn.Module):
    """Fuse per-timestep features of one scale into a single feature map.

    With more than one timestep the input is expected to hold the features
    of all timesteps concatenated on the channel axis; a 1x1 conv first
    reduces ``dim * len(time_steps)`` channels to ``dim``. A 3x3 conv to
    ``dim_out`` follows in every case. Each conv is followed by a ReLU.
    """

    def __init__(self, dim, dim_out, time_steps):
        super().__init__()
        n_steps = len(time_steps)
        layers = []
        if n_steps > 1:
            layers.extend([nn.Conv2d(dim * n_steps, dim, 1), nn.ReLU()])
        layers.extend([nn.Conv2d(dim, dim_out, 3, padding=1), nn.ReLU()])
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        fused = self.block(x)
        return fused
346
+
347
+
348
class cd_head_v2(nn.Module):
    """Change detection head (version 2).

    Feature indices in ``feat_scales`` are processed from the largest to the
    smallest. At each one a ``CDBlock`` fuses the features of all timesteps
    for both images and their absolute difference is taken; between
    consecutive indices an ``AttentionBlock`` changes the channel width and
    the result is bilinearly upsampled by 2 and added to the next
    difference. A two-conv classifier produces ``out_channels`` logits.
    """

    def __init__(self, feat_scales, out_channels=2, inner_channel=None, channel_multiplier=None, img_size=256, time_steps=None):
        super().__init__()
        self.feat_scales = sorted(feat_scales, reverse=True)
        self.in_channels = _get_in_channels(self.feat_scales, inner_channel, channel_multiplier)
        self.img_size = img_size
        self.time_steps = time_steps

        def width_at(scale):
            # Channel width of the single feature map at index ``scale``.
            return _get_in_channels([scale], inner_channel, channel_multiplier)

        self.decoder = nn.ModuleList()
        last_index = len(self.feat_scales) - 1
        for i, scale in enumerate(self.feat_scales):
            dim = width_at(scale)
            self.decoder.append(CDBlock(dim, dim, time_steps))
            if i < last_index:
                dim_out = width_at(self.feat_scales[i + 1])
                self.decoder.append(AttentionBlock(dim, dim_out))
        # ``dim_out`` is the width produced by the last AttentionBlock.
        self.clfr_stg1 = nn.Conv2d(dim_out, 64, 3, padding=1)
        self.clfr_stg2 = nn.Conv2d(64, out_channels, 3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, feats_A, feats_B):
        """Return change logits from per-timestep feature lists of two images.

        ``feats_A[i][s]`` / ``feats_B[i][s]`` must be the feature map of
        timestep ``i`` at feature index ``s``; one entry per configured
        timestep is read.
        """
        n_steps = len(self.time_steps)
        lvl, x = 0, None
        for layer in self.decoder:
            if not isinstance(layer, CDBlock):
                diff = layer(diff)
                x = F.interpolate(diff, scale_factor=2, mode="bilinear")
                continue

            scale = self.feat_scales[lvl]
            f_A = feats_A[0][scale]
            f_B = feats_B[0][scale]
            if n_steps > 1:
                # Stack every timestep's features on the channel axis.
                f_A = torch.cat([feats_A[i][scale] for i in range(n_steps)], dim=1)
                f_B = torch.cat([feats_B[i][scale] for i in range(n_steps)], dim=1)
            diff = torch.abs(layer(f_A) - layer(f_B))
            if lvl > 0:
                diff = diff + x
            lvl += 1
        # NOTE(review): the classifier reads ``x`` (the last upsampled map), not
        # the final ``diff``; kept as-is since trained weights depend on it.
        return self.clfr_stg2(self.relu(self.clfr_stg1(x)))
385
+
386
+
387
+ # ===========================================================================
388
+ # Diffusion utilities
389
+ # ===========================================================================
390
+
391
def _precompute_alpha_tables(scheduler):
    """Return ``sqrt(alphas_cumprod)`` as a NumPy array with a leading 1.0.

    The extra first entry means index ``t`` holds the value for timestep
    ``t`` counted from 1, and index 0 is the noise-free level.

    Args:
        scheduler: Object exposing ``alphas_cumprod`` as a torch tensor.

    Returns:
        A NumPy array with ``len(alphas_cumprod) + 1`` entries.
    """
    # detach()/cpu() keep the conversion valid even if the tensor was moved to
    # an accelerator or tracks gradients; a bare .numpy() raises in both cases.
    alphas_cumprod = scheduler.alphas_cumprod.detach().cpu().numpy()
    return np.sqrt(np.append(1.0, alphas_cumprod))
394
+
395
+
396
def _q_sample(x_start, continuous_sqrt_alpha_cumprod, noise=None):
    """Noise ``x_start`` to the level given by ``continuous_sqrt_alpha_cumprod``.

    Computes ``a * x_start + sqrt(1 - a**2) * noise`` where ``a`` is the
    given level; standard-normal noise is drawn when ``noise`` is None.
    """
    if noise is None:
        noise = torch.randn_like(x_start)
    signal_scale = continuous_sqrt_alpha_cumprod
    noise_scale = (1 - signal_scale ** 2).sqrt()
    return signal_scale * x_start + noise_scale * noise
400
+
401
+
402
@torch.no_grad()
def _extract_features(model, x, t, sqrt_alphas):
    """Noise ``x`` at timestep ``t`` and return the model's intermediate features.

    One noise level per batch element is drawn uniformly between
    ``sqrt_alphas[t - 1]`` and ``sqrt_alphas[t]``; ``x`` is noised to that
    level and passed to ``model(x_noisy, level, feat_need=True)``.

    Args:
        model: Callable accepting ``(x_noisy, level, feat_need=True)``.
        x: Batch of images; the first dimension is the batch size.
        t: Timestep index, counted from 1, into ``sqrt_alphas``.
        sqrt_alphas: Table indexed by timestep (index 0 is the clean level).

    Returns:
        Whatever ``model`` returns for ``feat_need=True``.

    Raises:
        ValueError: If ``t`` is smaller than 1. Without this check ``t - 1``
            would become a negative index and wrap to the end of the table.
    """
    t = int(t)
    if t < 1:
        raise ValueError(f"timestep must be >= 1, got {t}")
    b = x.shape[0]
    levels = np.random.uniform(sqrt_alphas[t - 1], sqrt_alphas[t], size=b)
    # Build the level directly in the input's floating dtype and device so the
    # noised image keeps the dtype the model was given.
    dtype = x.dtype if x.is_floating_point() else torch.float32
    lvl = torch.as_tensor(levels, dtype=dtype, device=x.device).view(b, -1)
    noise = torch.randn_like(x)
    x_noisy = _q_sample(x, lvl.view(-1, 1, 1, 1), noise)
    return model(x_noisy, lvl, feat_need=True)
411
+
412
+
413
+ # ===========================================================================
414
+ # Pipeline
415
+ # ===========================================================================
416
+
417
class DDPMCDPipeline(DiffusionPipeline):
    """DDPM-based change detection. Load with trust_remote_code=True.

    For the consolidated ddpm-cd repo with multiple cd_head variants, pass
    ``cd_head_subfolder`` (e.g. 'levir-50-100', 'whu-50-100-400',
    'cdd-50-100', etc.) when loading.

    The change-detection head is kept as the plain attribute ``cd_head``; it
    is not registered through ``register_modules``. It is read from
    ``<base>/cd_head[/<subfolder>]`` on the local filesystem.
    """

    def __init__(self, unet, scheduler, cd_head=None, cd_head_subfolder=None):
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)
        self.cd_head = cd_head
        self._cd_head_subfolder = cd_head_subfolder
        # Infer base path from unet config (dirname of unet subfolder = model root)
        unet_path = getattr(getattr(unet, "config", None), "_name_or_path", None)
        self._cd_head_base_path = os.path.dirname(unet_path) if unet_path else None

    def _resolve_cd_head_dir(self):
        """Return the directory holding the cd_head config, or None if there is none.

        Raises:
            RuntimeError: If ``cd_head/`` exists but has no ``config.json`` and
                no subfolder was selected (consolidated repo layout).
        """
        base = self._cd_head_base_path
        if base is None:
            cfg = getattr(self.unet, "config", None)
            unet_path = (getattr(cfg, "_name_or_path", None) or "") if cfg else ""
            base = os.path.dirname(unet_path)
        if not base or not os.path.isdir(base):
            return None  # no cd_head (e.g. pretrained-only model)

        subfolder = self._cd_head_subfolder
        cd_root = os.path.join(base, "cd_head")
        cd_dir = os.path.join(cd_root, subfolder) if subfolder else cd_root
        # Check the directory first so a repo without cd_head/ is handled
        # quietly instead of failing inside os.listdir below.
        if not os.path.isdir(cd_dir):
            return None  # no cd_head (e.g. pretrained-only model)

        if not subfolder and not os.path.isfile(os.path.join(cd_dir, "config.json")):
            # Consolidated repo: cd_head_subfolder is required
            subdirs = sorted(d for d in os.listdir(cd_dir) if os.path.isdir(os.path.join(cd_dir, d)))
            raise RuntimeError(
                "DDPMCDPipeline requires cd_head_subfolder when loading from consolidated ddpm-cd repo. "
                f"Available: {subdirs}. Example: from_pretrained(..., cd_head_subfolder='levir-50-100')"
            )
        return cd_dir

    @staticmethod
    def _load_cd_head_weights(head, cd_dir):
        """Load weights into ``head`` from ``cd_dir``; return True if a file was found."""
        for name in ("diffusion_pytorch_model.safetensors", "diffusion_pytorch_model.bin"):
            path = os.path.join(cd_dir, name)
            if not os.path.exists(path):
                continue
            if path.endswith(".safetensors"):
                from safetensors.torch import load_file

                head.load_state_dict(load_file(path, device="cpu"))
            else:
                try:
                    state = torch.load(path, map_location="cpu", weights_only=True)
                except TypeError:
                    # Older torch versions do not accept ``weights_only``.
                    state = torch.load(path, map_location="cpu")
                head.load_state_dict(state.state_dict() if hasattr(state, "state_dict") else state)
            return True
        return False

    def _load_cd_head_if_needed(self):
        """Lazy-load cd_head from disk when first needed (path inferred from unet).

        Does nothing when a head is already set or when no cd_head directory
        can be found.
        """
        if self.cd_head is not None:
            return
        cd_dir = self._resolve_cd_head_dir()
        if cd_dir is None:
            return
        with open(os.path.join(cd_dir, "config.json")) as f:
            cfg = json.load(f)
        head = cd_head_v2(**cfg)
        if not self._load_cd_head_weights(head, cd_dir):
            import warnings

            # A head without weights would predict from random initialisation.
            warnings.warn(
                f"No cd_head weights found in {cd_dir}; the head is randomly initialised.",
                stacklevel=2,
            )
        self.cd_head = head

    def load_cd_head(self, pretrained_model_name_or_path=None, subfolder=None):
        """Manually load cd_head from the given path (or infer from unet).

        Unlike the lazy loader this always attempts a fresh load, so it can
        be used to switch heads on an existing pipeline. If nothing can be
        loaded, the previously loaded head (if any) is kept.

        Args:
            pretrained_model_name_or_path: Local model root containing
                ``cd_head/``; the current base path is kept when omitted.
            subfolder: e.g. 'levir-50-100', 'whu-50-100-400' for the
                consolidated ddpm-cd repo.
        """
        if pretrained_model_name_or_path:
            self._cd_head_base_path = pretrained_model_name_or_path
        if subfolder is not None:
            self._cd_head_subfolder = subfolder

        previous = self.cd_head
        # Clear the current head, otherwise the lazy loader returns immediately.
        self.cd_head = None
        try:
            self._load_cd_head_if_needed()
        except Exception:
            self.cd_head = previous
            raise
        if self.cd_head is None and previous is not None:
            import warnings

            self.cd_head = previous
            warnings.warn(
                "Requested cd_head could not be found; keeping the previously loaded head.",
                stacklevel=2,
            )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Load the pipeline, then load the cd_head selected by ``cd_head_subfolder``.

        ``cd_head_subfolder`` is removed from ``kwargs`` before they are
        forwarded to the parent ``from_pretrained``.
        """
        cd_head_subfolder = kwargs.pop("cd_head_subfolder", None)
        pipe = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
        # Only a local directory can serve as base path; otherwise the loader
        # falls back to the path recorded in the unet config.
        is_local_dir = os.path.isdir(pretrained_model_name_or_path)
        pipe._cd_head_base_path = pretrained_model_name_or_path if is_local_dir else None
        pipe._cd_head_subfolder = cd_head_subfolder
        pipe._load_cd_head_if_needed()
        return pipe

    @torch.no_grad()
    def __call__(self, image_A, image_B, timesteps=None, feat_type="dec"):
        """Predict change logits between ``image_A`` and ``image_B``.

        Args:
            image_A: First image batch.
            image_B: Second image batch.
            timesteps: Diffusion timesteps at which features are extracted.
                Defaults to the loaded head's ``time_steps`` when available,
                otherwise ``[50, 100]``.
            feat_type: ``"dec"`` uses decoder features; any other value uses
                encoder features.

        Returns:
            The output of the cd_head for the extracted features.

        Raises:
            RuntimeError: If no cd_head is set and none could be loaded.
        """
        self._load_cd_head_if_needed()
        if self.cd_head is None:
            raise RuntimeError("DDPMCDPipeline requires cd_head. Could not load from disk.")
        if not timesteps:
            timesteps = getattr(self.cd_head, "time_steps", None) or [50, 100]

        # cd_head is a plain attribute rather than a registered module, so put
        # it on the UNet's device explicitly before feeding it UNet features.
        first_param = next(self.unet.parameters(), None)
        if first_param is not None:
            self.cd_head.to(first_param.device)

        sqrt_a = _precompute_alpha_tables(self.scheduler)
        use_decoder = feat_type == "dec"
        feats_A, feats_B = [], []
        for t in timesteps:
            fe_A, fd_A = _extract_features(self.unet, image_A, t, sqrt_a)
            fe_B, fd_B = _extract_features(self.unet, image_B, t, sqrt_a)
            feats_A.append(fd_A if use_decoder else fe_A)
            feats_B.append(fd_B if use_decoder else fe_B)
        return self.cd_head(feats_A, feats_B)

    @torch.no_grad()
    def generate(self, batch_size=1, in_channels=3, image_size=256, num_inference_steps=None, generator=None):
        """Sample images by iterating the scheduler from Gaussian noise.

        Args:
            batch_size: Number of images to sample.
            in_channels: Channels of the sampled tensor.
            image_size: Height and width of the sampled tensor.
            num_inference_steps: Scheduler steps; defaults to the scheduler's
                ``num_train_timesteps``.
            generator: Optional torch generator for the initial noise.

        Returns:
            The final sample tensor.
        """
        device = next(self.unet.parameters()).device
        steps = num_inference_steps or self.scheduler.config.num_train_timesteps
        sqrt_a = _precompute_alpha_tables(self.scheduler)
        image = torch.randn((batch_size, in_channels, image_size, image_size), device=device, generator=generator)
        self.scheduler.set_timesteps(steps)
        last_index = len(sqrt_a) - 1
        for t in tqdm(self.scheduler.timesteps, desc="Sampling"):
            # The table has a leading entry, so timestep t maps to index t + 1.
            idx = min(int(t) + 1, last_index)
            lvl = torch.FloatTensor([sqrt_a[idx]]).repeat(batch_size, 1).to(device)
            noise_pred = self.unet(image, lvl)
            image = self.scheduler.step(noise_pred, t, image).prev_sample
        return image
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "DDPMScheduler",
3
+ "_diffusers_version": "0.36.0",
4
+ "beta_end": 0.01,
5
+ "beta_schedule": "squaredcos_cap_v2",
6
+ "beta_start": 1e-06,
7
+ "clip_sample": false,
8
+ "clip_sample_range": 1.0,
9
+ "dynamic_thresholding_ratio": 0.995,
10
+ "num_train_timesteps": 2000,
11
+ "prediction_type": "epsilon",
12
+ "rescale_betas_zero_snr": false,
13
+ "sample_max_value": 1.0,
14
+ "steps_offset": 0,
15
+ "thresholding": false,
16
+ "timestep_spacing": "leading",
17
+ "trained_betas": null,
18
+ "variance_type": "fixed_small"
19
+ }
unet/config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "UNet",
3
+ "_diffusers_version": "0.36.0",
4
+ "_name_or_path": "D:\\sakura-project\\ddpm-cd-diffusers\\models\\BiliSakura\\BiliSakura\\ddpm-cd-pretrained-256\\unet",
5
+ "attn_res": [
6
+ 16
7
+ ],
8
+ "channel_mults": [
9
+ 1,
10
+ 2,
11
+ 4,
12
+ 8,
13
+ 8
14
+ ],
15
+ "dropout": 0.2,
16
+ "image_size": 256,
17
+ "in_channel": 3,
18
+ "inner_channel": 128,
19
+ "norm_groups": 32,
20
+ "out_channel": 3,
21
+ "res_blocks": 2,
22
+ "with_noise_level_emb": true
23
+ }
unet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:92980ede4037dcfec88f4626dd0353d74fa8e303fd867c3d426a6bb5cd416649
3
+ size 1564231460
unet/unet.py ADDED
@@ -0,0 +1,495 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Self-contained DDPMCDPipeline for change detection.
3
+ All custom code (UNet, cd_head, diffusion utils) in one file - no external repo needed.
4
+
5
+ Usage::
6
+
7
+ from diffusers import DiffusionPipeline
8
+
9
+ pipe = DiffusionPipeline.from_pretrained(
10
+ "path/to/ddpm-cd-levir-50-100",
11
+ trust_remote_code=True,
12
+ )
13
+ change_map = pipe(image_A, image_B, timesteps=[50, 100])
14
+ """
15
+
16
+ import json
17
+ import math
18
+ import os
19
+ from inspect import isfunction
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+ from diffusers import DDPMScheduler
26
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
27
+ from diffusers.models.modeling_utils import ModelMixin # ModelMixin subclasses nn.Module
28
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
29
+ from tqdm.auto import tqdm
30
+
31
+
32
+ # ===========================================================================
33
+ # UNet (SR3-style) - all components inlined
34
+ # ===========================================================================
35
+
36
def _exists(x):
    """Return True when ``x`` is anything other than ``None``."""
    if x is None:
        return False
    return True
38
+
39
+
40
def _default(val, d):
    """Return ``val`` unless it is ``None``; otherwise fall back to ``d``.

    A plain Python function passed as ``d`` is treated as a lazy factory and is
    called without arguments to produce the fallback value.
    """
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d
44
+
45
+
46
class PositionalEncoding(nn.Module):
    """Sinusoidal embedding of a noise level into ``2 * (dim // 2)`` features."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, noise_level):
        """Return ``sin`` and ``cos`` of the scaled noise level, concatenated on the last axis."""
        half = self.dim // 2
        # Fractions 0, 1/half, ... used as exponents of the 1e4 frequency base.
        exponents = torch.arange(half, dtype=noise_level.dtype, device=noise_level.device) / half
        frequencies = torch.exp(-math.log(1e4) * exponents.unsqueeze(0))
        phase = noise_level.unsqueeze(1) * frequencies
        return torch.cat([torch.sin(phase), torch.cos(phase)], dim=-1)
56
+
57
+
58
class FeatureWiseAffine(nn.Module):
    """Inject a noise embedding into a feature map, additively or as scale-and-shift."""

    def __init__(self, in_channels, out_channels, use_affine_level=False):
        super().__init__()
        self.use_affine_level = use_affine_level
        # The affine variant predicts gamma and beta, hence twice the width.
        projection_width = out_channels * (1 + self.use_affine_level)
        self.noise_func = nn.Sequential(nn.Linear(in_channels, projection_width))

    def forward(self, x, noise_embed):
        """Return ``x`` modulated by the projected ``noise_embed``."""
        batch = x.shape[0]
        modulation = self.noise_func(noise_embed).view(batch, -1, 1, 1)
        if not self.use_affine_level:
            return x + modulation
        gamma, beta = modulation.chunk(2, dim=1)
        return (1 + gamma) * x + beta
72
+
73
+
74
class Swish(nn.Module):
    """Swish activation: the input multiplied by its own sigmoid."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate
77
+
78
+
79
class Upsample(nn.Module):
    """Nearest-neighbour 2x upsampling followed by a 3x3 convolution."""

    def __init__(self, dim):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="nearest")
        self.conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1)

    def forward(self, x):
        enlarged = self.up(x)
        return self.conv(enlarged)
87
+
88
+
89
class Downsample(nn.Module):
    """Halve the spatial resolution with a stride-2 3x3 convolution."""

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv2d(dim, dim, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        reduced = self.conv(x)
        return reduced
96
+
97
+
98
class Block(nn.Module):
    """GroupNorm, Swish, optional Dropout, then a 3x3 convolution."""

    def __init__(self, dim, dim_out, groups=32, dropout=0):
        super().__init__()
        # Keep four entries in the Sequential even without dropout so that the
        # convolution always lives at index 3.
        regulariser = nn.Identity() if dropout == 0 else nn.Dropout(dropout)
        layers = [
            nn.GroupNorm(groups, dim),
            Swish(),
            regulariser,
            nn.Conv2d(dim, dim_out, 3, padding=1),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        out = self.block(x)
        return out
110
+
111
+
112
class ResnetBlock(nn.Module):
    """Residual block of two ``Block`` layers with noise-embedding injection in between."""

    def __init__(self, dim, dim_out, noise_level_emb_dim=None, dropout=0, use_affine_level=False, norm_groups=32):
        super().__init__()
        self.noise_func = FeatureWiseAffine(noise_level_emb_dim, dim_out, use_affine_level)
        self.block1 = Block(dim, dim_out, groups=norm_groups)
        self.block2 = Block(dim_out, dim_out, groups=norm_groups, dropout=dropout)
        # A 1x1 projection is only needed when the channel count changes.
        if dim != dim_out:
            self.res_conv = nn.Conv2d(dim, dim_out, 1)
        else:
            self.res_conv = nn.Identity()

    def forward(self, x, time_emb):
        """Return the block output added to the (possibly projected) input."""
        shortcut = self.res_conv
        hidden = self.noise_func(self.block1(x), time_emb)
        hidden = self.block2(hidden)
        return hidden + shortcut(x)
125
+
126
+
127
class SelfAttention(nn.Module):
    """Spatial self-attention over a feature map, with a residual connection."""

    def __init__(self, in_channel, n_head=1, norm_groups=32):
        super().__init__()
        self.n_head = n_head
        self.norm = nn.GroupNorm(norm_groups, in_channel)
        self.qkv = nn.Conv2d(in_channel, in_channel * 3, 1, bias=False)
        self.out = nn.Conv2d(in_channel, in_channel, 1)

    def forward(self, input):
        """Attend every spatial position to every other one and add the result to ``input``."""
        batch, channel, height, width = input.shape
        heads = self.n_head
        per_head = channel // heads

        projected = self.qkv(self.norm(input)).view(batch, heads, per_head * 3, height, width)
        query, key, value = projected.chunk(3, dim=2)

        # Scores are scaled by sqrt of the full channel count, not the per-head width.
        scores = torch.einsum("bnchw, bncyx -> bnhwyx", query, key).contiguous() / math.sqrt(channel)
        # Softmax runs over all key positions, so flatten the (y, x) axes first.
        weights = torch.softmax(scores.view(batch, heads, height, width, -1), -1)
        weights = weights.view(batch, heads, height, width, height, width)

        mixed = torch.einsum("bnhwyx, bncyx -> bnchw", weights, value).contiguous()
        return self.out(mixed.view(batch, channel, height, width)) + input
146
+
147
+
148
class ResnetBlocWithAttn(nn.Module):
    """A ``ResnetBlock`` optionally followed by ``SelfAttention``."""

    def __init__(self, dim, dim_out, *, noise_level_emb_dim=None, norm_groups=32, dropout=0, with_attn=False):
        super().__init__()
        self.with_attn = with_attn
        self.res_block = ResnetBlock(dim, dim_out, noise_level_emb_dim, norm_groups=norm_groups, dropout=dropout)
        # The attribute is always assigned; it is None when attention is disabled.
        self.attn = None
        if with_attn:
            self.attn = SelfAttention(dim_out, norm_groups=norm_groups)

    def forward(self, x, time_emb):
        """Apply the residual block, then attention when it is enabled."""
        x = self.res_block(x, time_emb)
        if not self.with_attn:
            return x
        return self.attn(x)
160
+
161
+
162
class UNet(ModelMixin, ConfigMixin):
    """SR3-style UNet with noise-level conditioning. Supports feat_need=True for intermediate features.

    Layout built by the constructor:

    * ``init_conv`` maps ``in_channel`` to ``inner_channel``.
    * ``downs``: for every entry of ``channel_mults``, ``res_blocks`` residual blocks,
      followed by a ``Downsample`` on all levels but the last.
    * ``mid``: two residual blocks, the first with attention and the second without.
    * ``ups``: for every level in reverse, ``res_blocks + 1`` residual blocks that each
      consume one skip feature, followed by an ``Upsample`` on all levels but the last.
    * ``final_conv`` maps to ``out_channel`` (or ``in_channel`` when ``out_channel`` is None).

    Attention is enabled on a level when the tracked resolution (starting at
    ``image_size`` and halved after each ``Downsample``) is contained in ``attn_res``.
    """

    @register_to_config
    def __init__(
        self,
        in_channel=6,
        out_channel=3,
        inner_channel=32,
        norm_groups=32,
        channel_mults=(1, 2, 4, 8, 8),
        attn_res=(8,),
        res_blocks=3,
        dropout=0,
        with_noise_level_emb=True,
        image_size=128,
    ):
        super().__init__()
        # NOTE(review): with_noise_level_emb=False forwards noise_level_emb_dim=None to
        # every residual block - confirm the blocks can be built that way.
        noise_level_channel = inner_channel if with_noise_level_emb else None
        # Embeds the noise level into an inner_channel-wide conditioning vector.
        self.noise_level_mlp = (
            nn.Sequential(
                PositionalEncoding(inner_channel),
                nn.Linear(inner_channel, inner_channel * 4),
                Swish(),
                nn.Linear(inner_channel * 4, inner_channel),
            )
            if with_noise_level_emb
            else None
        )

        num_mults = len(channel_mults)
        # feat_channels records the channel count of every skip feature pushed on the way
        # down (starting with the init_conv output) so the up path can size its inputs.
        pre_channel, feat_channels, now_res = inner_channel, [inner_channel], image_size
        self.init_conv = nn.Conv2d(in_channel, inner_channel, 3, padding=1)

        downs = []
        for ind in range(num_mults):
            use_attn = now_res in attn_res
            channel_mult = inner_channel * channel_mults[ind]
            for _ in range(res_blocks):
                downs.append(
                    ResnetBlocWithAttn(
                        pre_channel, channel_mult,
                        noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups,
                        dropout=dropout, with_attn=use_attn,
                    )
                )
                feat_channels.append(channel_mult)
                pre_channel = channel_mult
            # Every level except the last one ends with a stride-2 downsample.
            if ind < num_mults - 1:
                downs.append(Downsample(pre_channel))
                feat_channels.append(pre_channel)
                now_res = now_res // 2
        self.downs = nn.ModuleList(downs)

        self.mid = nn.ModuleList([
            ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel,
                               norm_groups=norm_groups, dropout=dropout, with_attn=True),
            ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel,
                               norm_groups=norm_groups, dropout=dropout, with_attn=False),
        ])

        ups = []
        for ind in reversed(range(num_mults)):
            use_attn = now_res in attn_res
            channel_mult = inner_channel * channel_mults[ind]
            # One more block than on the way down: each block pops exactly one recorded
            # skip, and each level recorded res_blocks + 1 of them.
            for _ in range(res_blocks + 1):
                ups.append(
                    ResnetBlocWithAttn(
                        pre_channel + feat_channels.pop(), channel_mult,
                        noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups,
                        dropout=dropout, with_attn=use_attn,
                    )
                )
                pre_channel = channel_mult
            if ind > 0:
                ups.append(Upsample(pre_channel))
                now_res = now_res * 2
        self.ups = nn.ModuleList(ups)
        self.final_conv = Block(pre_channel, _default(out_channel, lambda: in_channel), groups=norm_groups)

    def forward(self, x, time, feat_need=False):
        """Run the network.

        Args:
            x: input feature map fed to ``init_conv``.
            time: noise level passed to ``noise_level_mlp`` (ignored when that MLP is None).
            feat_need: when True, return intermediate features instead of the output.

        Returns:
            The ``final_conv`` output when ``feat_need`` is False. Otherwise a tuple
            ``(fe, fd)``: ``fe`` is the list of encoder features (``init_conv`` output
            followed by the output of every layer in ``downs``) and ``fd`` is the list of
            outputs of every residual block in ``ups``, reversed so that the block that
            ran last comes first.
        """
        t = self.noise_level_mlp(time) if _exists(self.noise_level_mlp) else None
        x = self.init_conv(x)
        feats = [x]
        for layer in self.downs:
            x = layer(x, t) if isinstance(layer, ResnetBlocWithAttn) else layer(x)
            feats.append(x)
        # Shallow copy: the up path below pops from ``feats``.
        fe = feats.copy() if feat_need else None
        for layer in self.mid:
            x = layer(x, t) if isinstance(layer, ResnetBlocWithAttn) else layer(x)
        fd = [] if feat_need else None
        for layer in self.ups:
            if isinstance(layer, ResnetBlocWithAttn):
                # Concatenate the most recent unused skip feature along the channel axis.
                x = layer(torch.cat((x, feats.pop()), dim=1), t)
                if feat_need:
                    fd.append(x)
            else:
                x = layer(x)
        # final_conv is evaluated even when only the features are returned.
        x = self.final_conv(x)
        return (fe, list(reversed(fd))) if feat_need else x
262
+
263
+
264
+ # ===========================================================================
265
+ # Change detection head
266
+ # ===========================================================================
267
+
268
class ChannelSELayer(nn.Module):
    """Channel squeeze-and-excitation: rescale each channel by a learned gate."""

    def __init__(self, num_channels, reduction_ratio=2):
        super().__init__()
        bottleneck = num_channels // reduction_ratio
        self.fc1 = nn.Linear(num_channels, bottleneck, bias=True)
        self.fc2 = nn.Linear(bottleneck, num_channels, bias=True)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` multiplied by per-channel gates computed from its spatial mean."""
        b, c, _, _ = x.size()
        squeezed = x.view(b, c, -1).mean(dim=2)
        hidden = self.relu(self.fc1(squeezed))
        gate = self.sigmoid(self.fc2(hidden))
        return x * gate.view(b, c, 1, 1)
281
+
282
+
283
class SpatialSELayer(nn.Module):
    """Spatial squeeze-and-excitation: rescale each position by a learned gate."""

    def __init__(self, num_channels):
        super().__init__()
        self.conv = nn.Conv2d(num_channels, 1, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, weights=None):
        """Return ``x`` gated per position.

        When ``weights`` is given it is used as a bias-free 1x1 kernel in place of the
        layer's own convolution.
        """
        b, c, h, w = x.size()
        if weights is None:
            squeezed = self.conv(x)
        else:
            squeezed = F.conv2d(x, weights.view(1, c, 1, 1))
        gate = self.sigmoid(squeezed).view(b, 1, h, w)
        return x * gate
293
+
294
+
295
class ChannelSpatialSELayer(nn.Module):
    """Sum of channel-wise and spatial squeeze-and-excitation applied to the same input."""

    def __init__(self, num_channels, reduction_ratio=2):
        super().__init__()
        self.cSE = ChannelSELayer(num_channels, reduction_ratio)
        self.sSE = SpatialSELayer(num_channels)

    def forward(self, x):
        channel_branch = self.cSE(x)
        spatial_branch = self.sSE(x)
        return channel_branch + spatial_branch
303
+
304
+
305
def _get_in_channels(feat_scales, inner_channel, channel_multiplier):
    """Return the summed channel count of the requested feature scales.

    Scales are grouped in runs of three (0-2, 3-5, 6-8, 9-11, 12-14); each run maps to
    one entry of ``channel_multiplier``, and a scale contributes
    ``inner_channel * channel_multiplier[run]``.

    Args:
        feat_scales: iterable of scale indices, each in ``0..14``.
        inner_channel: base channel width.
        channel_multiplier: sequence of per-level multipliers.

    Returns:
        The total number of channels; ``0`` for an empty ``feat_scales``.

    Raises:
        ValueError: if a scale lies outside ``0..14``. Negative scales used to be
            counted silently as level 0.
    """
    total = 0
    for scale in feat_scales:
        if not 0 <= scale < 15:
            raise ValueError("feat_scales 0<=s<=14")
        total += inner_channel * channel_multiplier[int(scale) // 3]
    return total
316
+
317
+
318
class AttentionBlock(nn.Module):
    """3x3 convolution and ReLU followed by channel-plus-spatial squeeze-and-excitation."""

    def __init__(self, dim, dim_out):
        super().__init__()
        stages = [
            nn.Conv2d(dim, dim_out, 3, padding=1),
            nn.ReLU(),
            ChannelSpatialSELayer(dim_out, 2),
        ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        out = self.block(x)
        return out
329
+
330
+
331
class CDBlock(nn.Module):
    """Fuse features that were concatenated over several time steps.

    With more than one time step a 1x1 convolution first reduces the concatenated
    ``dim * len(time_steps)`` channels back to ``dim``; a 3x3 convolution to
    ``dim_out`` follows in either case.
    """

    def __init__(self, dim, dim_out, time_steps):
        super().__init__()
        layers = []
        if len(time_steps) > 1:
            layers += [nn.Conv2d(dim * len(time_steps), dim, 1), nn.ReLU()]
        layers += [nn.Conv2d(dim, dim_out, 3, padding=1), nn.ReLU()]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        out = self.block(x)
        return out
344
+
345
+
346
class cd_head_v2(nn.Module):
    """Change detection head (version 2).

    The head walks the selected feature scales from the largest index to the smallest.
    At each scale a ``CDBlock`` fuses the per-time-step features of both images and the
    absolute difference is taken; between scales an ``AttentionBlock`` changes the
    channel count and the result is bilinearly upsampled by 2 and added to the next
    scale's difference. Two 3x3 convolutions produce the ``out_channels`` logits.

    Args:
        feat_scales: non-empty iterable of feature indices (``0..14``).
        out_channels: number of output classes.
        inner_channel: base channel width of the feature extractor.
        channel_multiplier: per-level channel multipliers of the feature extractor.
        img_size: stored on the module; not used by the code in this class.
        time_steps: sequence of diffusion time steps whose features are concatenated.

    Raises:
        ValueError: if ``feat_scales`` is empty.
    """

    def __init__(self, feat_scales, out_channels=2, inner_channel=None, channel_multiplier=None, img_size=256, time_steps=None):
        super().__init__()
        self.feat_scales = sorted(list(feat_scales), reverse=True)
        if not self.feat_scales:
            raise ValueError("feat_scales must contain at least one scale")
        self.in_channels = _get_in_channels(self.feat_scales, inner_channel, channel_multiplier)
        self.img_size, self.time_steps = img_size, time_steps
        self.decoder = nn.ModuleList()
        last = len(self.feat_scales) - 1
        for i, scale in enumerate(self.feat_scales):
            dim = _get_in_channels([scale], inner_channel, channel_multiplier)
            self.decoder.append(CDBlock(dim, dim, time_steps))
            if i < last:
                dim_out = _get_in_channels([self.feat_scales[i + 1]], inner_channel, channel_multiplier)
                self.decoder.append(AttentionBlock(dim, dim_out))
        # ``dim`` now holds the width of the final scale. With several scales this equals
        # the last AttentionBlock's output width (the value used before); with a single
        # scale no AttentionBlock exists and the old code hit an unbound ``dim_out``.
        self.clfr_stg1 = nn.Conv2d(dim, 64, 3, padding=1)
        self.clfr_stg2 = nn.Conv2d(64, out_channels, 3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, feats_A, feats_B):
        """Return change logits for two sets of features.

        Args:
            feats_A: one entry per time step; each entry is indexable by feature scale.
            feats_B: same structure as ``feats_A`` for the second image.
        """
        lvl, x, diff = 0, None, None
        for layer in self.decoder:
            if isinstance(layer, CDBlock):
                scale = self.feat_scales[lvl]
                f_A = feats_A[0][scale]
                f_B = feats_B[0][scale]
                if len(self.time_steps) > 1:
                    # Stack the same scale from every additional time step on channels.
                    for i in range(1, len(self.time_steps)):
                        f_A = torch.cat((f_A, feats_A[i][scale]), dim=1)
                        f_B = torch.cat((f_B, feats_B[i][scale]), dim=1)
                diff = torch.abs(layer(f_A) - layer(f_B))
                if lvl > 0:
                    diff = diff + x
                lvl += 1
            else:
                diff = layer(diff)
                x = F.interpolate(diff, scale_factor=2, mode="bilinear")
        # NOTE(review): with several scales the classifier reads ``x`` (the upsampled
        # output of the last AttentionBlock), so the final CDBlock's ``diff`` is unused.
        # Kept as is because released weights were presumably trained this way - confirm.
        # With a single scale there is no ``x``, so classify the difference directly.
        head_in = x if x is not None else diff
        return self.clfr_stg2(self.relu(self.clfr_stg1(head_in)))
383
+
384
+
385
+ # ===========================================================================
386
+ # Diffusion utilities
387
+ # ===========================================================================
388
+
389
def _precompute_alpha_tables(scheduler):
    """Return ``sqrt`` of the scheduler's cumulative alphas with a leading ``1.0`` entry.

    Args:
        scheduler: object exposing a 1-D tensor attribute ``alphas_cumprod``.

    Returns:
        A NumPy array with ``len(alphas_cumprod) + 1`` entries; index 0 is ``1.0`` and
        index ``t + 1`` is ``sqrt(alphas_cumprod[t])``.
    """
    # detach()/cpu() keep this working when the table is not on the CPU or is part of
    # an autograd graph; a bare .numpy() raises in both cases.
    ac = scheduler.alphas_cumprod.detach().cpu().numpy()
    return np.sqrt(np.append(1.0, ac))
392
+
393
+
394
def _q_sample(x_start, continuous_sqrt_alpha_cumprod, noise=None):
    """Return a noised version of ``x_start`` for the given ``sqrt(alpha_cumprod)``.

    Computes ``c * x_start + sqrt(1 - c**2) * noise`` with
    ``c = continuous_sqrt_alpha_cumprod``; Gaussian noise shaped like ``x_start`` is
    drawn when ``noise`` is None.
    """
    eps = torch.randn_like(x_start) if noise is None else noise
    signal = continuous_sqrt_alpha_cumprod * x_start
    noise_scale = (1 - continuous_sqrt_alpha_cumprod ** 2).sqrt()
    return signal + noise_scale * eps
398
+
399
+
400
@torch.no_grad()
def _extract_features(model, x, t, sqrt_alphas):
    """Noise ``x`` at time step ``t`` and return the model's intermediate features.

    A noise level is drawn per batch element uniformly between ``sqrt_alphas[t - 1]``
    and ``sqrt_alphas[t]`` (using NumPy's global random state), ``x`` is noised with it,
    and the model is called with ``feat_need=True``.

    Args:
        model: callable accepting ``(x_noisy, noise_level, feat_need=True)``.
        x: input batch; its first dimension is the batch size.
        t: time step, ``1 <= t < len(sqrt_alphas)``.
        sqrt_alphas: table indexable by time step.

    Returns:
        Whatever ``model(..., feat_need=True)`` returns.

    Raises:
        IndexError: if ``t`` is outside ``1..len(sqrt_alphas) - 1``. Previously ``t = 0``
            or a negative ``t`` wrapped around to the end of the table and silently
            used the wrong noise level.
    """
    if not 1 <= t < len(sqrt_alphas):
        raise IndexError(f"timestep {t} is out of range; expected 1 <= t <= {len(sqrt_alphas) - 1}")
    b = x.shape[0]
    lvl = torch.FloatTensor(
        np.random.uniform(sqrt_alphas[t - 1], sqrt_alphas[t], size=b)
    ).to(x.device).view(b, -1)
    noise = torch.randn_like(x)
    # Broadcast the per-sample level over the channel and spatial axes.
    x_noisy = _q_sample(x, lvl.view(-1, 1, 1, 1), noise)
    return model(x_noisy, lvl, feat_need=True)
409
+
410
+
411
+ # ===========================================================================
412
+ # Pipeline
413
+ # ===========================================================================
414
+
415
class DDPMCDPipeline(DiffusionPipeline):
    """DDPM-based change detection. Load with trust_remote_code=True.

    ``unet`` and ``scheduler`` are registered as pipeline modules. The change-detection
    head (``cd_head``) is kept as a plain attribute and is loaded lazily from
    ``<model dir>/cd_head`` or, when a sub-folder is configured, from
    ``<model dir>/cd_head/<cd_head_subfolder>``.
    """

    def __init__(self, unet, scheduler, cd_head=None):
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)
        self.cd_head = cd_head
        self._cd_head_base_path = None  # set when loaded via from_pretrained
        self._cd_head_subfolder = None  # optional variant directory below cd_head/

    @staticmethod
    def _load_cd_head_weights(cd_dir, head):
        """Load weights for ``head`` from ``cd_dir``; return False if no weight file exists."""
        for name in ("diffusion_pytorch_model.safetensors", "diffusion_pytorch_model.bin"):
            p = os.path.join(cd_dir, name)
            if not os.path.exists(p):
                continue
            if p.endswith(".safetensors"):
                from safetensors.torch import load_file
                head.load_state_dict(load_file(p, device="cpu"))
            else:
                try:
                    s = torch.load(p, map_location="cpu", weights_only=True)
                except TypeError:
                    # torch versions without ``weights_only``. NOTE: this unpickles the
                    # file, so only load .bin checkpoints from trusted sources.
                    s = torch.load(p, map_location="cpu")
                head.load_state_dict(s.state_dict() if hasattr(s, "state_dict") else s)
            return True
        return False

    def _load_cd_head_if_needed(self):
        """Lazy-load cd_head from disk when first needed (path inferred from unet).

        Raises:
            RuntimeError: if the model directory or the cd_head directory cannot be
                found, or if the directory holds no weight file.
            FileNotFoundError: if the cd_head ``config.json`` is missing.
        """
        if self.cd_head is not None:
            return
        base = self._cd_head_base_path
        if base is None:
            # Fall back to the parent of the directory the unet was loaded from.
            unet_cfg = getattr(self.unet, "config", None)
            base = os.path.dirname(getattr(unet_cfg, "_name_or_path", "")) if unet_cfg else None
        if not base or not os.path.isdir(base):
            raise RuntimeError("Cannot find model path to load cd_head. Use load_cd_head(path) or load from a full pipeline directory.")
        cd_dir = os.path.join(base, "cd_head")
        subfolder = getattr(self, "_cd_head_subfolder", None)
        if subfolder:
            cd_dir = os.path.join(cd_dir, subfolder)
        if not os.path.isdir(cd_dir):
            raise RuntimeError(f"cd_head directory not found at {cd_dir}")
        config_path = os.path.join(cd_dir, "config.json")
        if not os.path.isfile(config_path):
            # Point the caller at the variant directories when heads live one level down.
            variants = sorted(d for d in os.listdir(cd_dir) if os.path.isdir(os.path.join(cd_dir, d)))
            hint = f" Available cd_head_subfolder values: {variants}" if variants else ""
            raise FileNotFoundError(f"cd_head config not found at {config_path}.{hint}")
        with open(config_path, encoding="utf-8") as f:
            head_cfg = json.load(f)
        # Drop underscore-prefixed metadata keys (presumably written by the config
        # serializer); cd_head_v2 accepts no such argument.
        head_cfg = {k: v for k, v in head_cfg.items() if not k.startswith("_")}
        ch = cd_head_v2(**head_cfg)
        if not self._load_cd_head_weights(cd_dir, ch):
            # Never hand back a randomly initialised head as if it were pretrained.
            raise RuntimeError(f"No cd_head weights found in {cd_dir}")
        self.cd_head = ch

    def load_cd_head(self, pretrained_model_name_or_path=None, cd_head_subfolder=None):
        """Manually load cd_head from the given path (or infer from unet).

        When a path or sub-folder is given, an already loaded head is replaced; if the
        new head cannot be loaded the previous head and settings are restored and the
        error is re-raised. Without arguments this only loads a head if none is present.
        """
        if not pretrained_model_name_or_path and cd_head_subfolder is None:
            self._load_cd_head_if_needed()
            return
        previous = (self.cd_head, self._cd_head_base_path, getattr(self, "_cd_head_subfolder", None))
        if pretrained_model_name_or_path:
            self._cd_head_base_path = pretrained_model_name_or_path
        if cd_head_subfolder is not None:
            self._cd_head_subfolder = cd_head_subfolder
        self.cd_head = None
        try:
            self._load_cd_head_if_needed()
        except Exception:
            self.cd_head, self._cd_head_base_path, self._cd_head_subfolder = previous
            raise

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Load the pipeline and its change-detection head.

        Accepts the extra keyword ``cd_head_subfolder`` naming a variant directory below
        ``cd_head/``; it is not forwarded to the parent loader.
        """
        cd_head_subfolder = kwargs.pop("cd_head_subfolder", None)
        pipe = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
        pipe._cd_head_base_path = pretrained_model_name_or_path if os.path.isdir(pretrained_model_name_or_path) else None
        pipe._cd_head_subfolder = cd_head_subfolder
        pipe._load_cd_head_if_needed()
        return pipe

    @torch.no_grad()
    def __call__(self, image_A, image_B, timesteps=None, feat_type="dec"):
        """Return the change-detection head's output for an image pair.

        Args:
            image_A: first image batch, passed to the UNet feature extractor.
            image_B: second image batch.
            timesteps: diffusion time steps to extract features at. Defaults to the
                head's own ``time_steps`` when available, else ``[50, 100]``.
            feat_type: ``"dec"`` uses decoder features; any other value uses encoder
                features.

        Raises:
            RuntimeError: if no ``cd_head`` is available.
        """
        self._load_cd_head_if_needed()
        if self.cd_head is None:
            raise RuntimeError("DDPMCDPipeline requires cd_head. Could not load from disk.")
        if not timesteps:
            # The head indexes one feature set per configured time step, so its own
            # configuration is the safest default.
            timesteps = getattr(self.cd_head, "time_steps", None) or [50, 100]
        if isinstance(self.cd_head, nn.Module):
            # cd_head is not a registered pipeline module, so keep it on the unet's
            # device explicitly.
            self.cd_head.to(next(self.unet.parameters()).device)
        sqrt_a = _precompute_alpha_tables(self.scheduler)
        use_decoder = feat_type == "dec"
        feats_A, feats_B = [], []
        for t in timesteps:
            fe_A, fd_A = _extract_features(self.unet, image_A, t, sqrt_a)
            fe_B, fd_B = _extract_features(self.unet, image_B, t, sqrt_a)
            feats_A.append(fd_A if use_decoder else fe_A)
            feats_B.append(fd_B if use_decoder else fe_B)
        return self.cd_head(feats_A, feats_B)

    @torch.no_grad()
    def generate(self, batch_size=1, in_channels=3, image_size=256, num_inference_steps=None, generator=None):
        """Sample images from the unconditional DDPM by iterating the scheduler.

        Args:
            batch_size: number of images to sample.
            in_channels: channel count of the sampled tensor.
            image_size: height and width of the sampled tensor.
            num_inference_steps: number of scheduler steps; falls back to the
                scheduler's ``num_train_timesteps`` when falsy.
            generator: optional ``torch.Generator`` used for the initial noise and for
                the noise drawn inside every scheduler step.
        """
        device = next(self.unet.parameters()).device
        steps = num_inference_steps or self.scheduler.config.num_train_timesteps
        sqrt_a = _precompute_alpha_tables(self.scheduler)
        image = torch.randn((batch_size, in_channels, image_size, image_size), device=device, generator=generator)
        self.scheduler.set_timesteps(steps)
        for t in tqdm(self.scheduler.timesteps, desc="Sampling"):
            # The table has a leading 1.0 entry, so timestep t maps to index t + 1.
            idx = min(int(t) + 1, len(sqrt_a) - 1)
            lvl = torch.FloatTensor([sqrt_a[idx]]).repeat(batch_size, 1).to(device)
            noise_pred = self.unet(image, lvl)
            # Forward the generator so the per-step noise is reproducible as well.
            image = self.scheduler.step(noise_pred, t, image, generator=generator).prev_sample
        return image