BiliSakura committed on
Commit
b6acc0a
·
verified ·
1 Parent(s): 714cd75

Add files using upload-large-folder tool

.gitattributes CHANGED
@@ -1,35 +1,3 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
  *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
  *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,119 @@
+ ---
+ license: apache-2.0
+ library_name: diffusers
+ pipeline_tag: text-to-image
+ tags:
+ - remote-sensing
+ - diffusion
+ - controlnet
+ - custom-pipeline
+ language:
+ - en
+ ---
+
+ > [!WARNING]
+ > Checkpoint conversion is not fully validated. If the pipeline fails to load or produces unexpected output, please contact me at bili_sakura@zju.edu.cn.
+
+ # BiliSakura/CRS-Diff
+
+ Diffusers-style packaging of the CRS-Diff checkpoint, with a custom Hugging Face `DiffusionPipeline` implementation.
+
+ ## Model Details
+
+ - **Base project**: `CRS-Diff` (Controllable Remote Sensing Image Generation with Diffusion Model)
+ - **Checkpoint source**: `/root/worksapce/models/raw/CRS-Diff/last.ckpt`
+ - **Pipeline class**: `CRSDiffPipeline` (in `pipeline.py`)
+ - **Scheduler**: `DDIMScheduler`
+ - **Resolution**: 512x512 (default in the training/inference config)
+
+ ## Repository Structure
+
+ ```text
+ CRS-Diff/
+   pipeline.py
+   modular_pipeline.py
+   crs_core/
+     autoencoder.py
+     text_encoder.py
+     local_adapter.py
+     global_adapter.py
+     metadata_embedding.py
+     modules/
+   model_index.json
+   scheduler/
+     scheduler_config.json
+   unet/
+   vae/
+   text_encoder/
+   local_adapter/
+   global_content_adapter/
+   global_text_adapter/
+   metadata_encoder/
+ ```
+
+ ## Usage
+
+ Install the dependencies first:
+
+ ```bash
+ pip install diffusers transformers torch torchvision omegaconf einops safetensors pytorch-lightning
+ ```
+
+ Load the pipeline (from a local path or the Hub repo), then run inference:
+
+ ```python
+ import torch
+ from diffusers import DiffusionPipeline
+
+ pipe = DiffusionPipeline.from_pretrained(
+     "/root/worksapce/models/BiliSakura/CRS-Diff",
+     custom_pipeline="pipeline.py",
+     trust_remote_code=True,
+     model_path="/root/worksapce/models/BiliSakura/CRS-Diff",
+ )
+ pipe = pipe.to("cuda")
+
+ # Example placeholder controls; replace with real CRS controls.
+ b = 1
+ local_control = torch.zeros((b, 18, 512, 512), device="cuda", dtype=torch.float32)
+ global_control = torch.zeros((b, 1536), device="cuda", dtype=torch.float32)
+ metadata = torch.zeros((b, 7), device="cuda", dtype=torch.float32)
+
+ out = pipe(
+     prompt=["a remote sensing image of an urban area"],
+     negative_prompt=["blurry, distorted, overexposed"],
+     local_control=local_control,
+     global_control=global_control,
+     metadata=metadata,
+     num_inference_steps=50,
+     guidance_scale=7.5,
+     eta=0.0,
+     strength=1.0,
+     global_strength=1.0,
+     output_type="pil",
+ )
+ image = out.images[0]
+ image.save("crs_diff_sample.png")
+ ```
+
+ ## Notes
+
+ - This repository is packaged in a diffusers-compatible layout with a custom pipeline.
+ - The loading path follows the same placeholder-aware custom-pipeline pattern as HSIGene.
+ - Split component weights are provided in diffusers-style folders (`unet/`, `vae/`, adapters, and encoders).
+ - The monolithic `crs_model/last.ckpt` fallback has been intentionally removed; this repo ships split components only.
+ - Legacy external source trees (`models/`, `ldm/`) are removed; the runtime code lives in the lightweight `crs_core/` package.
+ - `CRSDiffPipeline` expects CRS-specific condition tensors (`local_control`, `global_control`, `metadata`).
+ - If you publish to the Hugging Face Hub, keep `trust_remote_code=True` when loading.
+
+ ## Citation
+
+ ```bibtex
+ @article{tang2024crs,
+   title={CRS-Diff: Controllable remote sensing image generation with diffusion model},
+   author={Tang, Datao and Cao, Xiangyong and Hou, Xingsong and Jiang, Zhongyuan and Liu, Junmin and Meng, Deyu},
+   journal={IEEE Transactions on Geoscience and Remote Sensing},
+   year={2024},
+   publisher={IEEE}
+ }
+ ```
__pycache__/modular_pipeline.cpython-312.pyc ADDED
Binary file (9.96 kB).
 
crs_core/__init__.py ADDED
File without changes
crs_core/autoencoder.py ADDED
@@ -0,0 +1,26 @@
+ import torch
+ import torch.nn as nn
+
+ from crs_core.modules.diffusionmodules.model import Encoder, Decoder
+ from crs_core.modules.distributions.distributions import DiagonalGaussianDistribution
+
+
+ class AutoencoderKL(nn.Module):
+     def __init__(self, ddconfig, lossconfig=None, embed_dim=4, **kwargs):
+         super().__init__()
+         del lossconfig, kwargs
+         self.encoder = Encoder(**ddconfig)
+         self.decoder = Decoder(**ddconfig)
+         assert ddconfig["double_z"]
+         self.quant_conv = nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
+         self.post_quant_conv = nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+         self.embed_dim = embed_dim
+
+     def encode(self, x):
+         h = self.encoder(x)
+         moments = self.quant_conv(h)
+         return DiagonalGaussianDistribution(moments)
+
+     def decode(self, z):
+         z = self.post_quant_conv(z)
+         return self.decoder(z)
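For reference, `encode` returns a `DiagonalGaussianDistribution` over the `moments` tensor (stacked mean and log-variance along the channel dimension). The real class lives in `crs_core/modules/distributions/distributions.py` and is not shown in this diff; the standalone sketch below is an assumption based on the standard Stable Diffusion VAE convention, shown here for `embed_dim=4`:

```python
import torch

class DiagonalGaussian:
    """Minimal stand-in for DiagonalGaussianDistribution (assumed SD convention)."""

    def __init__(self, moments: torch.Tensor):
        # moments has 2 * embed_dim channels: first half mean, second half logvar
        self.mean, logvar = torch.chunk(moments, 2, dim=1)
        self.logvar = torch.clamp(logvar, -30.0, 20.0)
        self.std = torch.exp(0.5 * self.logvar)

    def sample(self) -> torch.Tensor:
        # reparameterized sample: mean + std * eps
        return self.mean + self.std * torch.randn_like(self.mean)

    def kl(self) -> torch.Tensor:
        # KL divergence to a standard normal, summed over non-batch dims
        return 0.5 * torch.sum(
            self.mean ** 2 + self.logvar.exp() - 1.0 - self.logvar,
            dim=[1, 2, 3],
        )

moments = torch.randn(1, 8, 64, 64)  # 2 * embed_dim = 8 channels for a 512x512 input
dist = DiagonalGaussian(moments)
z = dist.sample()
print(z.shape)  # torch.Size([1, 4, 64, 64])
```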
crs_core/global_adapter.py ADDED
@@ -0,0 +1,59 @@
+ import math
+
+ import torch
+ import torch.nn as nn
+ from einops import rearrange
+
+ from crs_core.modules.attention import FeedForward
+
+
+ class FixedPositionalEncoding(nn.Module):
+     def __init__(self, d_model, max_len=5000):
+         super(FixedPositionalEncoding, self).__init__()
+         pe = torch.zeros(max_len, d_model)
+         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         pe = pe.unsqueeze(0).transpose(0, 1)
+         self.register_buffer('pe', pe, persistent=False)
+
+     def forward(self, x):
+         x = x + self.pe[:x.size(0), :]
+         return x
+
+
+ class GlobalTextAdapter(nn.Module):
+     def __init__(self, in_dim, max_len=768):
+         super().__init__()
+         self.in_dim = in_dim
+         # self.positional_encoding = FixedPositionalEncoding(d_model=in_dim, max_len=max_len)
+
+     def forward(self, x):
+         # x = self.positional_encoding(x)
+         return x
+
+
+ class GlobalContentAdapter(nn.Module):
+     def __init__(self, in_dim, channel_mult=[2, 4]):
+         super().__init__()
+         dim_out1, mult1 = in_dim * channel_mult[0], channel_mult[0] * 2
+         dim_out2, mult2 = in_dim * channel_mult[1], channel_mult[1] * 2 // channel_mult[0]
+         self.in_dim = in_dim
+         self.channel_mult = channel_mult
+
+         self.ff1 = FeedForward(in_dim, dim_out=dim_out1, mult=mult1, glu=True, dropout=0.1)
+         self.ff2 = FeedForward(dim_out1, dim_out=dim_out2, mult=mult2, glu=True, dropout=0.3)
+         self.norm1 = nn.LayerNorm(in_dim)
+         self.norm2 = nn.LayerNorm(dim_out1)
+
+     def forward(self, x):
+         x = self.ff1(self.norm1(x))
+         x = self.ff2(self.norm2(x))
+         x = rearrange(x, 'b (n d) -> b n d', n=self.channel_mult[-1], d=self.in_dim).contiguous()
+         return x
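`GlobalContentAdapter` turns a single global vector into a short token sequence: two gated feed-forward blocks expand `in_dim` to `in_dim * channel_mult[1]`, and the final `rearrange` splits that into `channel_mult[-1]` tokens of width `in_dim`. A shape-only sketch of that flow, where plain `nn.Linear` stands in for the `FeedForward` blocks and `in_dim=768` is an illustrative choice, not a value taken from the repo's config:

```python
import torch

# Shape walk-through with in_dim=768, channel_mult=[2, 4] (illustrative values).
in_dim, b = 768, 2
ff1 = torch.nn.Linear(in_dim, in_dim * 2)      # dim_out1 = 1536
ff2 = torch.nn.Linear(in_dim * 2, in_dim * 4)  # dim_out2 = 3072

x = torch.randn(b, in_dim)
x = ff2(ff1(x))                                # (b, 3072)
# rearrange 'b (n d) -> b n d' with n=channel_mult[-1]=4, d=in_dim=768
tokens = x.view(b, 4, in_dim)
print(tokens.shape)  # torch.Size([2, 4, 768])
```

Each of the four resulting tokens can then attend like an extra context token of the same width as the text embedding.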
crs_core/local_adapter.py ADDED
@@ -0,0 +1,461 @@
+ import torch
+ import torch as th
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from crs_core.modules.diffusionmodules.util import (
+     checkpoint,
+     conv_nd,
+     linear,
+     zero_module,
+     timestep_embedding,
+ )
+ from crs_core.modules.attention import SpatialTransformer
+ from crs_core.modules.diffusionmodules.openaimodel import UNetModel, TimestepBlock, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
+ from crs_core.utils import exists
+
+
+ class LocalTimestepEmbedSequential(nn.Sequential, TimestepBlock):
+     def forward(self, x, emb, context=None, local_features=None):
+         for layer in self:
+             if isinstance(layer, TimestepBlock):
+                 x = layer(x, emb)
+             elif isinstance(layer, SpatialTransformer):
+                 x = layer(x, context)
+             elif isinstance(layer, LocalResBlock):
+                 x = layer(x, emb, local_features)
+             else:
+                 x = layer(x)
+         return x
+
+
+ class FDN(nn.Module):
+     def __init__(self, norm_nc, label_nc):
+         super().__init__()
+         ks = 3
+         pw = ks // 2
+         self.param_free_norm = nn.GroupNorm(32, norm_nc, affine=False)
+         self.conv_gamma = nn.Conv2d(label_nc, norm_nc, kernel_size=ks, padding=pw)
+         self.conv_beta = nn.Conv2d(label_nc, norm_nc, kernel_size=ks, padding=pw)
+
+     def forward(self, x, local_features):
+         normalized = self.param_free_norm(x)
+         assert local_features.size()[2:] == x.size()[2:]
+         gamma = self.conv_gamma(local_features)
+         beta = self.conv_beta(local_features)
+         out = normalized * (1 + gamma) + beta
+         return out
+
+
+ class SelfAttention(nn.Module):
+     def __init__(self, in_dim):
+         super(SelfAttention, self).__init__()
+         # Query, Key, Value transformations
+         self.query_conv = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)
+         self.key_conv = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)
+         self.value_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1)
+         # Softmax attention
+         self.softmax = nn.Softmax(dim=-1)
+
+     def forward(self, x):
+         batch, C, width, height = x.size()
+         query = self.query_conv(x).view(batch, -1, width * height).permute(0, 2, 1)
+         key = self.key_conv(x).view(batch, -1, width * height)
+         value = self.value_conv(x).view(batch, -1, width * height)
+
+         attention = self.softmax(torch.bmm(query, key))
+         out = torch.bmm(value, attention.permute(0, 2, 1))
+         out = out.view(batch, C, width, height)
+
+         return out + x  # Skip connection
+
+
+ class EnhancedFDN(nn.Module):
+     def __init__(self, norm_nc, label_nc):
+         super(EnhancedFDN, self).__init__()
+         self.fdn = FDN(norm_nc, label_nc)
+         self.attention = SelfAttention(norm_nc)
+
+     def forward(self, x, local_features):
+         x = self.attention(x)
+         out = self.fdn(x, local_features)
+         return out
+
+
+ class LocalResBlock(nn.Module):
+     def __init__(
+         self,
+         channels,
+         emb_channels,
+         dropout,
+         out_channels=None,
+         dims=2,
+         use_checkpoint=False,
+         inject_channels=None
+     ):
+         super().__init__()
+         self.channels = channels
+         self.emb_channels = emb_channels
+         self.dropout = dropout
+         self.out_channels = out_channels or channels
+         self.use_checkpoint = use_checkpoint
+         self.norm_in = EnhancedFDN(channels, inject_channels)
+         self.norm_out = EnhancedFDN(self.out_channels, inject_channels)
+
+         self.in_layers = nn.Sequential(
+             nn.Identity(),
+             nn.SiLU(),
+             conv_nd(dims, channels, self.out_channels, 3, padding=1),
+         )
+
+         self.emb_layers = nn.Sequential(
+             nn.SiLU(),
+             linear(
+                 emb_channels,
+                 self.out_channels,
+             ),
+         )
+         self.out_layers = nn.Sequential(
+             nn.Identity(),
+             nn.SiLU(),
+             nn.Dropout(p=dropout),
+             zero_module(
+                 conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
+             ),
+         )
+
+         if self.out_channels == channels:
+             self.skip_connection = nn.Identity()
+         else:
+             self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+     def forward(self, x, emb, local_conditions):
+         return checkpoint(
+             self._forward, (x, emb, local_conditions), self.parameters(), self.use_checkpoint
+         )
+
+     def _forward(self, x, emb, local_conditions):
+         h = self.norm_in(x, local_conditions)
+         h = self.in_layers(h)
+
+         emb_out = self.emb_layers(emb).type(h.dtype)
+         while len(emb_out.shape) < len(h.shape):
+             emb_out = emb_out[..., None]
+
+         h = h + emb_out
+         h = self.norm_out(h, local_conditions)
+         h = self.out_layers(h)
+
+         return self.skip_connection(x) + h
+
+
+ class FeatureExtractor(nn.Module):
+     def __init__(self, local_channels, inject_channels, dims=2):
+         super().__init__()
+         self.pre_extractor = LocalTimestepEmbedSequential(
+             conv_nd(dims, local_channels, 32, 3, padding=1),
+             nn.SiLU(),
+             conv_nd(dims, 32, 64, 3, padding=1, stride=2),
+             nn.SiLU(),
+             conv_nd(dims, 64, 64, 3, padding=1),
+             nn.SiLU(),
+             conv_nd(dims, 64, 128, 3, padding=1, stride=2),
+             nn.SiLU(),
+             conv_nd(dims, 128, 128, 3, padding=1),
+             nn.SiLU(),
+         )
+         self.extractors = nn.ModuleList([
+             LocalTimestepEmbedSequential(
+                 conv_nd(dims, 128, inject_channels[0], 3, padding=1, stride=2),
+                 nn.SiLU()
+             ),
+             LocalTimestepEmbedSequential(
+                 conv_nd(dims, inject_channels[0], inject_channels[1], 3, padding=1, stride=2),
+                 nn.SiLU()
+             ),
+             LocalTimestepEmbedSequential(
+                 conv_nd(dims, inject_channels[1], inject_channels[2], 3, padding=1, stride=2),
+                 nn.SiLU()
+             ),
+             LocalTimestepEmbedSequential(
+                 conv_nd(dims, inject_channels[2], inject_channels[3], 3, padding=1, stride=2),
+                 nn.SiLU()
+             )
+         ])
+         self.zero_convs = nn.ModuleList([
+             zero_module(conv_nd(dims, inject_channels[0], inject_channels[0], 3, padding=1)),
+             zero_module(conv_nd(dims, inject_channels[1], inject_channels[1], 3, padding=1)),
+             zero_module(conv_nd(dims, inject_channels[2], inject_channels[2], 3, padding=1)),
+             zero_module(conv_nd(dims, inject_channels[3], inject_channels[3], 3, padding=1))
+         ])
+
+     def forward(self, local_conditions):
+         local_features = self.pre_extractor(local_conditions, None)
+         assert len(self.extractors) == len(self.zero_convs)
+
+         output_features = []
+         for idx in range(len(self.extractors)):
+             local_features = self.extractors[idx](local_features, None)
+             output_features.append(self.zero_convs[idx](local_features))
+         return output_features
+
+
+ class LocalAdapter(nn.Module):
+     def __init__(
+         self,
+         in_channels,
+         model_channels,
+         local_channels,
+         inject_channels,
+         inject_layers,
+         num_res_blocks,
+         attention_resolutions,
+         dropout=0,
+         channel_mult=(1, 2, 4, 8),
+         conv_resample=True,
+         dims=2,
+         use_checkpoint=False,
+         use_fp16=False,
+         num_heads=-1,
+         num_head_channels=-1,
+         num_heads_upsample=-1,
+         use_scale_shift_norm=False,
+         resblock_updown=False,
+         use_new_attention_order=False,
+         use_spatial_transformer=False,
+         transformer_depth=1,
+         context_dim=None,
+         n_embed=None,
+         legacy=True,
+         disable_self_attentions=None,
+         num_attention_blocks=None,
+         disable_middle_self_attn=False,
+         use_linear_in_transformer=False,
+     ):
+         super().__init__()
+
+         if context_dim is not None:
+             from omegaconf.listconfig import ListConfig
+             if type(context_dim) == ListConfig:
+                 context_dim = list(context_dim)
+
+         if num_heads_upsample == -1:
+             num_heads_upsample = num_heads
+
+         self.dims = dims
+         self.in_channels = in_channels
+         self.model_channels = model_channels
+         self.inject_layers = inject_layers
+         if isinstance(num_res_blocks, int):
+             self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+         else:
+             if len(num_res_blocks) != len(channel_mult):
+                 raise ValueError("provide num_res_blocks either as an int (globally constant) or "
+                                  "as a list/tuple (per-level) with the same length as channel_mult")
+             self.num_res_blocks = num_res_blocks
+         if disable_self_attentions is not None:
+             # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
+             assert len(disable_self_attentions) == len(channel_mult)
+         if num_attention_blocks is not None:
+             assert len(num_attention_blocks) == len(self.num_res_blocks)
+             assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
+             print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+                   f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+                   f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+                   f"attention will still not be set.")
+
+         self.attention_resolutions = attention_resolutions
+         self.dropout = dropout
+         self.channel_mult = channel_mult
+         self.conv_resample = conv_resample
+         self.use_checkpoint = use_checkpoint
+         self.dtype = th.float16 if use_fp16 else th.float32
+         self.num_heads = num_heads
+         self.num_head_channels = num_head_channels
+         self.num_heads_upsample = num_heads_upsample
+         self.predict_codebook_ids = n_embed is not None
+
+         time_embed_dim = model_channels * 4
+         self.time_embed = nn.Sequential(
+             linear(model_channels, time_embed_dim),
+             nn.SiLU(),
+             linear(time_embed_dim, time_embed_dim),
+         )
+
+         self.feature_extractor = FeatureExtractor(local_channels, inject_channels)
+         self.input_blocks = nn.ModuleList(
+             [
+                 LocalTimestepEmbedSequential(
+                     conv_nd(dims, in_channels, model_channels, 3, padding=1)
+                 )
+             ]
+         )
+         self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
+
+         self._feature_size = model_channels
+         input_block_chans = [model_channels]
+         ch = model_channels
+         ds = 1
+         for level, mult in enumerate(channel_mult):
+             for nr in range(self.num_res_blocks[level]):
+                 if (1 + 3 * level + nr) in self.inject_layers:
+                     layers = [
+                         LocalResBlock(
+                             ch,
+                             time_embed_dim,
+                             dropout,
+                             out_channels=mult * model_channels,
+                             dims=dims,
+                             use_checkpoint=use_checkpoint,
+                             inject_channels=inject_channels[level]
+                         )
+                     ]
+                 else:
+                     layers = [
+                         ResBlock(
+                             ch,
+                             time_embed_dim,
+                             dropout,
+                             out_channels=mult * model_channels,
+                             dims=dims,
+                             use_checkpoint=use_checkpoint,
+                             use_scale_shift_norm=use_scale_shift_norm,
+                         )
+                     ]
+                 ch = mult * model_channels
+                 if ds in attention_resolutions:
+                     if num_head_channels == -1:
+                         dim_head = ch // num_heads
+                     else:
+                         num_heads = ch // num_head_channels
+                         dim_head = num_head_channels
+                     if legacy:
+                         dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+                     if exists(disable_self_attentions):
+                         disabled_sa = disable_self_attentions[level]
+                     else:
+                         disabled_sa = False
+
+                     if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
+                         layers.append(
+                             AttentionBlock(
+                                 ch,
+                                 use_checkpoint=use_checkpoint,
+                                 num_heads=num_heads,
+                                 num_head_channels=dim_head,
+                                 use_new_attention_order=use_new_attention_order,
+                             ) if not use_spatial_transformer else SpatialTransformer(
+                                 ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                                 disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
+                                 use_checkpoint=use_checkpoint
+                             )
+                         )
+                 self.input_blocks.append(LocalTimestepEmbedSequential(*layers))
+                 self.zero_convs.append(self.make_zero_conv(ch))
+                 self._feature_size += ch
+                 input_block_chans.append(ch)
+             if level != len(channel_mult) - 1:
+                 out_ch = ch
+                 self.input_blocks.append(
+                     LocalTimestepEmbedSequential(
+                         ResBlock(
+                             ch,
+                             time_embed_dim,
+                             dropout,
+                             out_channels=out_ch,
+                             dims=dims,
+                             use_checkpoint=use_checkpoint,
+                             use_scale_shift_norm=use_scale_shift_norm,
+                             down=True,
+                         )
+                         if resblock_updown
+                         else Downsample(
+                             ch, conv_resample, dims=dims, out_channels=out_ch
+                         )
+                     )
+                 )
+                 ch = out_ch
+                 input_block_chans.append(ch)
+                 self.zero_convs.append(self.make_zero_conv(ch))
+                 ds *= 2
+                 self._feature_size += ch
+
+         if num_head_channels == -1:
+             dim_head = ch // num_heads
+         else:
+             num_heads = ch // num_head_channels
+             dim_head = num_head_channels
+         if legacy:
+             dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+         self.middle_block = LocalTimestepEmbedSequential(
+             ResBlock(
+                 ch,
+                 time_embed_dim,
+                 dropout,
+                 dims=dims,
+                 use_checkpoint=use_checkpoint,
+                 use_scale_shift_norm=use_scale_shift_norm,
+             ),
+             AttentionBlock(
+                 ch,
+                 use_checkpoint=use_checkpoint,
+                 num_heads=num_heads,
+                 num_head_channels=dim_head,
+                 use_new_attention_order=use_new_attention_order,
+             ) if not use_spatial_transformer else SpatialTransformer(
+                 ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                 disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
+                 use_checkpoint=use_checkpoint
+             ),
+             ResBlock(
+                 ch,
+                 time_embed_dim,
+                 dropout,
+                 dims=dims,
+                 use_checkpoint=use_checkpoint,
+                 use_scale_shift_norm=use_scale_shift_norm,
+             ),
+         )
+         self.middle_block_out = self.make_zero_conv(ch)
+         self._feature_size += ch
+
+     def make_zero_conv(self, channels):
+         return LocalTimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
+
+     def forward(self, x, timesteps, context, local_conditions, **kwargs):
+         t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+         emb = self.time_embed(t_emb)
+         local_features = self.feature_extractor(local_conditions)
+
+         outs = []
+         h = x.type(self.dtype)
+         for layer_idx, (module, zero_conv) in enumerate(zip(self.input_blocks, self.zero_convs)):
+             if layer_idx in self.inject_layers:
+                 h = module(h, emb, context, local_features[self.inject_layers.index(layer_idx)])
+             else:
+                 h = module(h, emb, context)
+             outs.append(zero_conv(h, emb, context))
+
+         h = self.middle_block(h, emb, context)
+         outs.append(self.middle_block_out(h, emb, context))
+
+         return outs
+
+
+ class LocalControlUNetModel(UNetModel):
+     def forward(self, x, timesteps=None, metadata=None, context=None, local_control=None, meta=False, **kwargs):
+         hs = []
+         t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+         emb = self.time_embed(t_emb) + metadata
+         h = x.type(self.dtype)
+         for module in self.input_blocks:
+             h = module(h, emb, context)
+             hs.append(h)
+         h = self.middle_block(h, emb, context)
+
+         h += local_control.pop()
+
+         for module in self.output_blocks:
+             h = torch.cat([h, hs.pop() + local_control.pop()], dim=1)
+             h = module(h, emb, context)
+
+         h = h.type(x.dtype)
+         return self.out(h)
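`FDN` is the key conditioning primitive here: it applies a parameter-free `GroupNorm`, then modulates the result with per-pixel `gamma`/`beta` maps predicted from the injected local features, in the style of SPADE. A minimal self-contained sketch of that modulation (the group count and channel sizes are shrunk from the real model, which uses 32 groups, purely so the example runs standalone):

```python
import torch
import torch.nn as nn

class TinyFDN(nn.Module):
    """Parameter-free norm + spatial gamma/beta modulation (FDN-style sketch)."""

    def __init__(self, norm_nc: int, label_nc: int):
        super().__init__()
        self.norm = nn.GroupNorm(8, norm_nc, affine=False)  # real model: 32 groups
        self.gamma = nn.Conv2d(label_nc, norm_nc, 3, padding=1)
        self.beta = nn.Conv2d(label_nc, norm_nc, 3, padding=1)

    def forward(self, x: torch.Tensor, feats: torch.Tensor) -> torch.Tensor:
        # normalized * (1 + gamma) + beta, matching FDN.forward above
        n = self.norm(x)
        return n * (1 + self.gamma(feats)) + self.beta(feats)

x = torch.randn(1, 64, 16, 16)      # UNet feature map
feats = torch.randn(1, 32, 16, 16)  # injected local-control features (same H, W)
out = TinyFDN(64, 32)(x, feats)
print(out.shape)  # torch.Size([1, 64, 16, 16])
```

Because `gamma`/`beta` vary per pixel, the local control can steer spatial structure rather than only global statistics, which is why `FDN` asserts that the feature maps share the same spatial size.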
crs_core/metadata_embedding.py ADDED
@@ -0,0 +1,44 @@
+ import torch
+ import torch.nn as nn
+
+ from crs_core.modules.diffusionmodules.util import SinusoidalEmbedding, create_condition_vector
+
+
+ class MetadataMLP(nn.Module):
+     def __init__(self, input_dim, embedding_dim):
+         super(MetadataMLP, self).__init__()
+         self.fc1 = nn.Linear(input_dim, embedding_dim)
+         # self.activation = nn.SiLU()
+         # self.fc2 = nn.Linear(embedding_dim, embedding_dim)
+
+     def forward(self, x):
+         out = self.fc1(x)
+         # out = self.activation(out)
+         # out = self.fc2(out)
+         return out
+
+
+ class metadata_embeddings(nn.Module):
+     def __init__(self, max_value, embedding_dim, max_period, metadata_dim):
+         super().__init__()
+         self.sinusoidal_embedding = SinusoidalEmbedding(max_value, embedding_dim)
+         self.mlp_models = nn.ModuleList([MetadataMLP(embedding_dim, embedding_dim * 4) for _ in range(metadata_dim)])
+         self.max_period = max_period
+         self.embedding_dim = embedding_dim
+         self.metadata_dim = metadata_dim
+         self.max_value = max_value
+
+     def forward(self, metadata=None):
+         while len(metadata) == 1:
+             metadata = metadata[0]
+         if metadata.dim() == 1:
+             metadata = metadata.unsqueeze(0)
+         embedded_metadata = self.sinusoidal_embedding(metadata)
+         condition_vector = create_condition_vector(embedded_metadata, self.mlp_models, self.embedding_dim)
+         return condition_vector
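Each scalar metadata field (the README's `metadata` tensor has 7 of them) is passed through a sinusoidal encoding before its per-field MLP. The repo's `SinusoidalEmbedding` is not shown in this diff, so its exact normalization is unknown; the sketch below uses the standard sin/cos formulation as an assumption:

```python
import math
import torch

def sinusoidal_embed(values: torch.Tensor, dim: int, max_period: float = 10000.0) -> torch.Tensor:
    """Embed each scalar in `values` into `dim` sin/cos features (assumed convention)."""
    half = dim // 2
    # geometric frequency ladder from 1 down to 1/max_period
    freqs = torch.exp(-math.log(max_period) * torch.arange(half) / half)
    args = values[..., None].float() * freqs  # (..., half)
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

meta = torch.tensor([[10.0, 45.0, 3.0]])  # (batch, metadata_dim), illustrative values
emb = sinusoidal_embed(meta, 64)
print(emb.shape)  # torch.Size([1, 3, 64])
```

Each field then has its own `MetadataMLP`, so differently scaled quantities (e.g. angles vs. counts) are projected independently before being combined into one condition vector.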
crs_core/modules/__init__.py ADDED
File without changes
crs_core/modules/attention.py ADDED
@@ -0,0 +1,341 @@
1
+ from inspect import isfunction
2
+ import math
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn, einsum
6
+ from einops import rearrange, repeat
7
+ from typing import Optional, Any
8
+
9
+ from crs_core.modules.diffusionmodules.util import checkpoint
10
+
11
+
12
+ try:
13
+ import xformers
14
+ import xformers.ops
15
+ XFORMERS_IS_AVAILBLE = True
16
+ except:
17
+ XFORMERS_IS_AVAILBLE = False
18
+
19
+ # CrossAttn precision handling
20
+ import os
21
+ _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
22
+
23
+ def exists(val):
24
+ return val is not None
25
+
26
+
27
+ def uniq(arr):
28
+ return{el: True for el in arr}.keys()
29
+
30
+
31
+ def default(val, d):
32
+ if exists(val):
33
+ return val
34
+ return d() if isfunction(d) else d
35
+
36
+
37
+ def max_neg_value(t):
38
+ return -torch.finfo(t.dtype).max
39
+
40
+
41
+ def init_(tensor):
42
+ dim = tensor.shape[-1]
43
+ std = 1 / math.sqrt(dim)
44
+ tensor.uniform_(-std, std)
45
+ return tensor
46
+
47
+
48
+ # feedforward
49
+ class GEGLU(nn.Module):
50
+ def __init__(self, dim_in, dim_out):
51
+ super().__init__()
52
+ self.proj = nn.Linear(dim_in, dim_out * 2)
53
+
54
+ def forward(self, x):
55
+ x, gate = self.proj(x).chunk(2, dim=-1)
56
+ return x * F.gelu(gate)
57
+
58
+
59
+ class FeedForward(nn.Module):
60
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
61
+ super().__init__()
62
+ inner_dim = int(dim * mult)
63
+ dim_out = default(dim_out, dim)
64
+ project_in = nn.Sequential(
65
+ nn.Linear(dim, inner_dim),
66
+ nn.GELU()
67
+ ) if not glu else GEGLU(dim, inner_dim)
68
+
69
+ self.net = nn.Sequential(
70
+ project_in,
71
+ nn.Dropout(dropout),
72
+ nn.Linear(inner_dim, dim_out)
73
+ )
74
+
75
+ def forward(self, x):
76
+ # print(x.shape)
77
+ return self.net(x)
78
+
79
+
80
+ def zero_module(module):
81
+ """
82
+ Zero out the parameters of a module and return it.
83
+ """
84
+ for p in module.parameters():
85
+ p.detach().zero_()
86
+ return module
87
+
88
+
89
+ def Normalize(in_channels):
90
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
91
+
92
+
93
+ class SpatialSelfAttention(nn.Module):
94
+ def __init__(self, in_channels):
95
+ super().__init__()
96
+ self.in_channels = in_channels
97
+
98
+ self.norm = Normalize(in_channels)
99
+ self.q = torch.nn.Conv2d(in_channels,
100
+ in_channels,
101
+ kernel_size=1,
102
+ stride=1,
103
+ padding=0)
104
+ self.k = torch.nn.Conv2d(in_channels,
105
+ in_channels,
106
+ kernel_size=1,
107
+ stride=1,
108
+ padding=0)
109
+ self.v = torch.nn.Conv2d(in_channels,
110
+ in_channels,
111
+ kernel_size=1,
112
+ stride=1,
113
+ padding=0)
114
+ self.proj_out = torch.nn.Conv2d(in_channels,
115
+ in_channels,
116
+ kernel_size=1,
117
+ stride=1,
118
+ padding=0)
119
+
120
+ def forward(self, x):
121
+ h_ = x
122
+ h_ = self.norm(h_)
123
+ q = self.q(h_)
124
+ k = self.k(h_)
125
+ v = self.v(h_)
126
+
127
+ # compute attention
128
+ b,c,h,w = q.shape
129
+ q = rearrange(q, 'b c h w -> b (h w) c')
130
+ k = rearrange(k, 'b c h w -> b c (h w)')
131
+ w_ = torch.einsum('bij,bjk->bik', q, k)
132
+
133
+ w_ = w_ * (int(c)**(-0.5))
134
+ w_ = torch.nn.functional.softmax(w_, dim=2)
135
+
136
+ # attend to values
137
+ v = rearrange(v, 'b c h w -> b c (h w)')
138
+ w_ = rearrange(w_, 'b i j -> b j i')
139
+ h_ = torch.einsum('bij,bjk->bik', v, w_)
140
+ h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
141
+ h_ = self.proj_out(h_)
142
+
143
+ return x+h_
144
+
145
+
146
+ class CrossAttention(nn.Module):
147
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
148
+ super().__init__()
149
+ inner_dim = dim_head * heads
150
+ context_dim = default(context_dim, query_dim)
151
+
152
+ self.scale = dim_head ** -0.5
153
+ self.heads = heads
154
+
155
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
156
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
157
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
158
+
159
+ self.to_out = nn.Sequential(
160
+ nn.Linear(inner_dim, query_dim),
161
+ nn.Dropout(dropout)
162
+ )
163
+
164
+ def forward(self, x, context=None, mask=None):
165
+ h = self.heads
166
+
167
+ q = self.to_q(x)
168
+ context = default(context, x)
169
+ k = self.to_k(context)
170
+ v = self.to_v(context)
171
+
172
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
173
+
174
+ # force cast to fp32 to avoid overflowing
175
+ if _ATTN_PRECISION == "fp32":
176
+ with torch.autocast(enabled=False, device_type='cuda'):
177
+ q, k = q.float(), k.float()
178
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
179
+ else:
180
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
181
+
182
+ del q, k
183
+
184
+ if exists(mask):
185
+ mask = rearrange(mask, 'b ... -> b (...)')
186
+ max_neg_value = -torch.finfo(sim.dtype).max
187
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
188
+ sim.masked_fill_(~mask, max_neg_value)
189
+
190
+ # attention, what we cannot get enough of
191
+ sim = sim.softmax(dim=-1)
192
+
193
+ out = einsum('b i j, b j d -> b i d', sim, v)
194
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
195
+ return self.to_out(out)
196
+
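The head split/merge and einsum pattern in `CrossAttention.forward` can be mirrored with a self-contained numpy sketch (shapes and contraction strings chosen to match the `rearrange`/`einsum` calls above; this is an illustration, not the module itself):

```python
import numpy as np

def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

b, n, h, d = 2, 5, 8, 64               # batch, tokens, heads, dim_head
q = np.random.randn(b, n, h * d)
k = np.random.randn(b, n, h * d)
v = np.random.randn(b, n, h * d)

# 'b n (h d) -> (b h) n d', as in the rearrange in forward()
split = lambda t: t.reshape(b, n, h, d).transpose(0, 2, 1, 3).reshape(b * h, n, d)
q, k, v = map(split, (q, k, v))

sim = np.einsum('bid,bjd->bij', q, k) * d ** -0.5   # scaled dot-product scores
attn = softmax(sim, axis=-1)
out = np.einsum('bij,bjd->bid', attn, v)

# '(b h) n d -> b n (h d)'
out = out.reshape(b, h, n, d).transpose(0, 2, 1, 3).reshape(b, n, h * d)
assert out.shape == (2, 5, 512)
```

The `to_out` linear then maps the concatenated heads (`h * d` = `inner_dim`) back to `query_dim`.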
197
+
198
+ class MemoryEfficientCrossAttention(nn.Module):
199
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
200
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
201
+ super().__init__()
202
+ print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
203
+ f"{heads} heads.")
204
+ inner_dim = dim_head * heads
205
+ context_dim = default(context_dim, query_dim)
206
+
207
+ self.heads = heads
208
+ self.dim_head = dim_head
209
+
210
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
211
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
212
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
213
+
214
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
215
+ self.attention_op: Optional[Any] = None
216
+
217
+ def forward(self, x, context=None, mask=None):
218
+ q = self.to_q(x)
219
+ context = default(context, x)
220
+ k = self.to_k(context)
221
+ v = self.to_v(context)
222
+
223
+ b, _, _ = q.shape
224
+ q, k, v = map(
225
+ lambda t: t.unsqueeze(3)
226
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
227
+ .permute(0, 2, 1, 3)
228
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
229
+ .contiguous(),
230
+ (q, k, v),
231
+ )
232
+
233
+ # actually compute the attention, what we cannot get enough of
234
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
235
+
236
+ if exists(mask):
237
+ raise NotImplementedError
238
+ out = (
239
+ out.unsqueeze(0)
240
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
241
+ .permute(0, 2, 1, 3)
242
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
243
+ )
244
+ return self.to_out(out)
245
+
246
+
247
+ class BasicTransformerBlock(nn.Module):
248
+ ATTENTION_MODES = {
249
+ "softmax": CrossAttention, # vanilla attention
250
+ "softmax-xformers": MemoryEfficientCrossAttention
251
+ }
252
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
253
+ disable_self_attn=False):
254
+ super().__init__()
255
+ attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
256
+ assert attn_mode in self.ATTENTION_MODES
257
+ attn_cls = self.ATTENTION_MODES[attn_mode]
258
+ self.disable_self_attn = disable_self_attn
259
+ self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
260
+ context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
261
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
262
+ self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
263
+ heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
264
+ self.norm1 = nn.LayerNorm(dim)
265
+ self.norm2 = nn.LayerNorm(dim)
266
+ self.norm3 = nn.LayerNorm(dim)
267
+ self.checkpoint = checkpoint
268
+
269
+ def forward(self, x, context=None):
270
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
271
+
272
+ def _forward(self, x, context=None):
273
+ x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
274
+ x = self.attn2(self.norm2(x), context=context) + x
275
+ x = self.ff(self.norm3(x)) + x
276
+ return x
277
+
278
+
279
+ class SpatialTransformer(nn.Module):
280
+ """
281
+ Transformer block for image-like data.
282
+ First, project the input (aka embedding)
283
+ and reshape to b, t, d.
284
+ Then apply standard transformer action.
285
+ Finally, reshape to image
286
+ NEW: use_linear for more efficiency instead of the 1x1 convs
287
+ """
288
+ def __init__(self, in_channels, n_heads, d_head,
289
+ depth=1, dropout=0., context_dim=None,
290
+ disable_self_attn=False, use_linear=False,
291
+ use_checkpoint=True):
292
+ super().__init__()
293
+ if exists(context_dim) and not isinstance(context_dim, list):
294
+ context_dim = [context_dim]
295
+ self.in_channels = in_channels
296
+ inner_dim = n_heads * d_head
297
+ self.norm = Normalize(in_channels)
298
+ if not use_linear:
299
+ self.proj_in = nn.Conv2d(in_channels,
300
+ inner_dim,
301
+ kernel_size=1,
302
+ stride=1,
303
+ padding=0)
304
+ else:
305
+ self.proj_in = nn.Linear(in_channels, inner_dim)
306
+
307
+ self.transformer_blocks = nn.ModuleList(
308
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
309
+ disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
310
+ for d in range(depth)]
311
+ )
312
+ if not use_linear:
313
+ self.proj_out = zero_module(nn.Conv2d(inner_dim,
314
+ in_channels,
315
+ kernel_size=1,
316
+ stride=1,
317
+ padding=0))
318
+ else:
319
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
320
+ self.use_linear = use_linear
321
+
322
+ def forward(self, x, context=None):
323
+ # note: if no context is given, cross-attention defaults to self-attention
324
+ if not isinstance(context, list):
325
+ context = [context]
326
+ b, c, h, w = x.shape
327
+ x_in = x
328
+ x = self.norm(x)
329
+ if not self.use_linear:
330
+ x = self.proj_in(x)
331
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
332
+ if self.use_linear:
333
+ x = self.proj_in(x)
334
+ for i, block in enumerate(self.transformer_blocks):
335
+ x = block(x, context=context[i])
336
+ if self.use_linear:
337
+ x = self.proj_out(x)
338
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
339
+ if not self.use_linear:
340
+ x = self.proj_out(x)
341
+ return x + x_in
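`SpatialTransformer` flattens the spatial grid to a token sequence, runs the transformer blocks, and reshapes back. A hedged numpy sketch of that round trip (the identity stands in for the transformer blocks here):

```python
import numpy as np

b, c, h, w = 1, 4, 8, 8
x = np.random.randn(b, c, h, w)

# 'b c h w -> b (h w) c': each spatial position becomes one token
tokens = x.reshape(b, c, h * w).transpose(0, 2, 1)
assert tokens.shape == (b, h * w, c)

# 'b (h w) c -> b c h w': restore the image layout after attention
y = tokens.transpose(0, 2, 1).reshape(b, c, h, w)
assert np.allclose(y, x)  # the reshape pair is lossless
```

Because `proj_out` is zero-initialized, `forward` returns exactly `x_in` at initialization; the transformer contribution is learned on top of the skip connection.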
crs_core/modules/diffusionmodules/__init__.py ADDED
File without changes
crs_core/modules/diffusionmodules/model.py ADDED
@@ -0,0 +1,853 @@
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ from einops import rearrange
7
+ from typing import Optional, Any
8
+
9
+ from crs_core.modules.attention import MemoryEfficientCrossAttention
10
+
11
+ try:
12
+ import xformers
13
+ import xformers.ops
14
+ XFORMERS_IS_AVAILBLE = True
15
+ except ImportError:
16
+ XFORMERS_IS_AVAILBLE = False
17
+ print("No module 'xformers'. Proceeding without it.")
18
+
19
+
20
+ def get_timestep_embedding(timesteps, embedding_dim):
21
+ """
22
+ Build sinusoidal timestep embeddings (taken from Fairseq).
23
+ This matches the implementation in Denoising Diffusion Probabilistic
24
+ Models and in tensor2tensor, but differs slightly from the
25
+ description in Section 3.5 of "Attention Is All You Need".
26
+ Returns a tensor of shape (len(timesteps), embedding_dim).
27
+ """
28
+ assert len(timesteps.shape) == 1
29
+
30
+ half_dim = embedding_dim // 2
31
+ emb = math.log(10000) / (half_dim - 1)
32
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
33
+ emb = emb.to(device=timesteps.device)
34
+ emb = timesteps.float()[:, None] * emb[None, :]
35
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
36
+ if embedding_dim % 2 == 1: # zero pad
37
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
38
+ return emb
39
+
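The embedding math above can be checked with a small self-contained numpy version (same formula, independent of torch; a sketch for even `dim` only):

```python
import math
import numpy as np

def timestep_embedding(timesteps, dim):
    half = dim // 2
    # geometric frequency ladder, identical to the torch code above
    freqs = np.exp(np.arange(half) * -(math.log(10000) / (half - 1)))
    args = timesteps[:, None].astype(np.float64) * freqs[None, :]
    return np.concatenate([np.sin(args), np.cos(args)], axis=1)

emb = timestep_embedding(np.array([0, 10, 100]), 128)
assert emb.shape == (3, 128)
assert np.allclose(emb[0, :64], 0.0)  # sin(0) = 0 at every frequency
assert np.allclose(emb[0, 64:], 1.0)  # cos(0) = 1 at every frequency
```

Distinct timesteps map to distinct phase patterns, which the two `temb.dense` linears later project into the ResNet blocks.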
40
+
41
+ def nonlinearity(x):
42
+ # swish
43
+ return x*torch.sigmoid(x)
44
+
45
+
46
+ def Normalize(in_channels, num_groups=32):
47
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
48
+
49
+
50
+ class Upsample(nn.Module):
51
+ def __init__(self, in_channels, with_conv):
52
+ super().__init__()
53
+ self.with_conv = with_conv
54
+ if self.with_conv:
55
+ self.conv = torch.nn.Conv2d(in_channels,
56
+ in_channels,
57
+ kernel_size=3,
58
+ stride=1,
59
+ padding=1)
60
+
61
+ def forward(self, x):
62
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
63
+ if self.with_conv:
64
+ x = self.conv(x)
65
+ return x
66
+
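`Upsample` doubles the spatial resolution with nearest-neighbour interpolation (optionally followed by a 3x3 conv). The interpolation step is equivalent to repeating each pixel into a 2x2 block, as this numpy sketch shows:

```python
import numpy as np

x = np.arange(4.0).reshape(1, 1, 2, 2)      # b, c, h, w
up = x.repeat(2, axis=2).repeat(2, axis=3)  # nearest-neighbour, scale_factor=2
assert up.shape == (1, 1, 4, 4)
# the top-left 2x2 block is a constant copy of the original top-left pixel
assert np.allclose(up[0, 0, :2, :2], x[0, 0, 0, 0])
```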
67
+
68
+ class Downsample(nn.Module):
69
+ def __init__(self, in_channels, with_conv):
70
+ super().__init__()
71
+ self.with_conv = with_conv
72
+ if self.with_conv:
73
+ # no asymmetric padding in torch conv, must do it ourselves
74
+ self.conv = torch.nn.Conv2d(in_channels,
75
+ in_channels,
76
+ kernel_size=3,
77
+ stride=2,
78
+ padding=0)
79
+
80
+ def forward(self, x):
81
+ if self.with_conv:
82
+ pad = (0,1,0,1)
83
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
84
+ x = self.conv(x)
85
+ else:
86
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
87
+ return x
88
+
89
+
90
+ class ResnetBlock(nn.Module):
91
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
92
+ dropout, temb_channels=512):
93
+ super().__init__()
94
+ self.in_channels = in_channels
95
+ out_channels = in_channels if out_channels is None else out_channels
96
+ self.out_channels = out_channels
97
+ self.use_conv_shortcut = conv_shortcut
98
+
99
+ self.norm1 = Normalize(in_channels)
100
+ self.conv1 = torch.nn.Conv2d(in_channels,
101
+ out_channels,
102
+ kernel_size=3,
103
+ stride=1,
104
+ padding=1)
105
+ if temb_channels > 0:
106
+ self.temb_proj = torch.nn.Linear(temb_channels,
107
+ out_channels)
108
+ self.norm2 = Normalize(out_channels)
109
+ self.dropout = torch.nn.Dropout(dropout)
110
+ self.conv2 = torch.nn.Conv2d(out_channels,
111
+ out_channels,
112
+ kernel_size=3,
113
+ stride=1,
114
+ padding=1)
115
+ if self.in_channels != self.out_channels:
116
+ if self.use_conv_shortcut:
117
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
118
+ out_channels,
119
+ kernel_size=3,
120
+ stride=1,
121
+ padding=1)
122
+ else:
123
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
124
+ out_channels,
125
+ kernel_size=1,
126
+ stride=1,
127
+ padding=0)
128
+
129
+ def forward(self, x, temb):
130
+ h = x
131
+ h = self.norm1(h)
132
+ h = nonlinearity(h)
133
+ h = self.conv1(h)
134
+
135
+ if temb is not None:
136
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
137
+
138
+ h = self.norm2(h)
139
+ h = nonlinearity(h)
140
+ h = self.dropout(h)
141
+ h = self.conv2(h)
142
+
143
+ if self.in_channels != self.out_channels:
144
+ if self.use_conv_shortcut:
145
+ x = self.conv_shortcut(x)
146
+ else:
147
+ x = self.nin_shortcut(x)
148
+
149
+ return x+h
150
+
151
+
152
+ class AttnBlock(nn.Module):
153
+ def __init__(self, in_channels):
154
+ super().__init__()
155
+ self.in_channels = in_channels
156
+
157
+ self.norm = Normalize(in_channels)
158
+ self.q = torch.nn.Conv2d(in_channels,
159
+ in_channels,
160
+ kernel_size=1,
161
+ stride=1,
162
+ padding=0)
163
+ self.k = torch.nn.Conv2d(in_channels,
164
+ in_channels,
165
+ kernel_size=1,
166
+ stride=1,
167
+ padding=0)
168
+ self.v = torch.nn.Conv2d(in_channels,
169
+ in_channels,
170
+ kernel_size=1,
171
+ stride=1,
172
+ padding=0)
173
+ self.proj_out = torch.nn.Conv2d(in_channels,
174
+ in_channels,
175
+ kernel_size=1,
176
+ stride=1,
177
+ padding=0)
178
+
179
+ def forward(self, x):
180
+ h_ = x
181
+ h_ = self.norm(h_)
182
+ q = self.q(h_)
183
+ k = self.k(h_)
184
+ v = self.v(h_)
185
+
186
+ # compute attention
187
+ b,c,h,w = q.shape
188
+ q = q.reshape(b,c,h*w)
189
+ q = q.permute(0,2,1) # b,hw,c
190
+ k = k.reshape(b,c,h*w) # b,c,hw
191
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
192
+ w_ = w_ * (int(c)**(-0.5))
193
+ w_ = torch.nn.functional.softmax(w_, dim=2)
194
+
195
+ # attend to values
196
+ v = v.reshape(b,c,h*w)
197
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
198
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
199
+ h_ = h_.reshape(b,c,h,w)
200
+
201
+ h_ = self.proj_out(h_)
202
+
203
+ return x+h_
204
+
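`AttnBlock` is single-head self-attention over spatial positions, written with `reshape`/`permute`/`bmm` instead of einsum. A numpy sketch of the same contraction pattern (illustrative shapes, not the module):

```python
import numpy as np

def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

b, c, h, w = 1, 16, 4, 4
q = np.random.randn(b, c, h, w)
k = np.random.randn(b, c, h, w)
v = np.random.randn(b, c, h, w)

q2 = q.reshape(b, c, h * w).transpose(0, 2, 1)  # b, hw, c
k2 = k.reshape(b, c, h * w)                     # b, c, hw
w_ = softmax(q2 @ k2 * c ** -0.5, axis=2)       # b, hw, hw attention weights
assert np.allclose(w_.sum(axis=2), 1.0)         # each query attends with weight 1

v2 = v.reshape(b, c, h * w)
# matches torch.bmm(v, w_.permute(0, 2, 1)) in the forward above
out = (v2 @ w_.transpose(0, 2, 1)).reshape(b, c, h, w)
assert out.shape == (b, c, h, w)
```

Every spatial location attends to every other, which is why this block only appears at the lowest resolutions (`attn_resolutions`), where `h * w` is small.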
205
+ class MemoryEfficientAttnBlock(nn.Module):
206
+ """
207
+ Uses xformers efficient implementation,
208
+ see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
209
+ Note: this is a single-head self-attention operation
210
+ """
211
+ #
212
+ def __init__(self, in_channels):
213
+ super().__init__()
214
+ self.in_channels = in_channels
215
+
216
+ self.norm = Normalize(in_channels)
217
+ self.q = torch.nn.Conv2d(in_channels,
218
+ in_channels,
219
+ kernel_size=1,
220
+ stride=1,
221
+ padding=0)
222
+ self.k = torch.nn.Conv2d(in_channels,
223
+ in_channels,
224
+ kernel_size=1,
225
+ stride=1,
226
+ padding=0)
227
+ self.v = torch.nn.Conv2d(in_channels,
228
+ in_channels,
229
+ kernel_size=1,
230
+ stride=1,
231
+ padding=0)
232
+ self.proj_out = torch.nn.Conv2d(in_channels,
233
+ in_channels,
234
+ kernel_size=1,
235
+ stride=1,
236
+ padding=0)
237
+ self.attention_op: Optional[Any] = None
238
+
239
+ def forward(self, x):
240
+ h_ = x
241
+ h_ = self.norm(h_)
242
+ q = self.q(h_)
243
+ k = self.k(h_)
244
+ v = self.v(h_)
245
+
246
+ # compute attention
247
+ B, C, H, W = q.shape
248
+ q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
249
+
250
+ q, k, v = map(
251
+ lambda t: t.unsqueeze(3)
252
+ .reshape(B, t.shape[1], 1, C)
253
+ .permute(0, 2, 1, 3)
254
+ .reshape(B * 1, t.shape[1], C)
255
+ .contiguous(),
256
+ (q, k, v),
257
+ )
258
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
259
+
260
+ out = (
261
+ out.unsqueeze(0)
262
+ .reshape(B, 1, out.shape[1], C)
263
+ .permute(0, 2, 1, 3)
264
+ .reshape(B, out.shape[1], C)
265
+ )
266
+ out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
267
+ out = self.proj_out(out)
268
+ return x+out
269
+
270
+
271
+ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
272
+ def forward(self, x, context=None, mask=None):
273
+ b, c, h, w = x.shape
274
+ x = rearrange(x, 'b c h w -> b (h w) c')
275
+ out = super().forward(x, context=context, mask=mask)
276
+ # add the residual while both tensors are still 'b (h w) c', then restore the image layout
+ out = rearrange(x + out, 'b (h w) c -> b c h w', h=h, w=w, c=c)
277
+ return out
278
+
279
+
280
+ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
281
+ assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
282
+ if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
283
+ attn_type = "vanilla-xformers"
284
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
285
+ if attn_type == "vanilla":
286
+ assert attn_kwargs is None
287
+ return AttnBlock(in_channels)
288
+ elif attn_type == "vanilla-xformers":
289
+ print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
290
+ return MemoryEfficientAttnBlock(in_channels)
291
+ elif attn_type == "memory-efficient-cross-attn":
292
+ attn_kwargs["query_dim"] = in_channels
293
+ return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
294
+ elif attn_type == "none":
295
+ return nn.Identity(in_channels)
296
+ else:
297
+ raise NotImplementedError()
298
+
299
+
300
+ class Model(nn.Module):
301
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
302
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
303
+ resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
304
+ super().__init__()
305
+ if use_linear_attn: attn_type = "linear"
306
+ self.ch = ch
307
+ self.temb_ch = self.ch*4
308
+ self.num_resolutions = len(ch_mult)
309
+ self.num_res_blocks = num_res_blocks
310
+ self.resolution = resolution
311
+ self.in_channels = in_channels
312
+
313
+ self.use_timestep = use_timestep
314
+ if self.use_timestep:
315
+ # timestep embedding
316
+ self.temb = nn.Module()
317
+ self.temb.dense = nn.ModuleList([
318
+ torch.nn.Linear(self.ch,
319
+ self.temb_ch),
320
+ torch.nn.Linear(self.temb_ch,
321
+ self.temb_ch),
322
+ ])
323
+
324
+ # downsampling
325
+ self.conv_in = torch.nn.Conv2d(in_channels,
326
+ self.ch,
327
+ kernel_size=3,
328
+ stride=1,
329
+ padding=1)
330
+
331
+ curr_res = resolution
332
+ in_ch_mult = (1,)+tuple(ch_mult)
333
+ self.down = nn.ModuleList()
334
+ for i_level in range(self.num_resolutions):
335
+ block = nn.ModuleList()
336
+ attn = nn.ModuleList()
337
+ block_in = ch*in_ch_mult[i_level]
338
+ block_out = ch*ch_mult[i_level]
339
+ for i_block in range(self.num_res_blocks):
340
+ block.append(ResnetBlock(in_channels=block_in,
341
+ out_channels=block_out,
342
+ temb_channels=self.temb_ch,
343
+ dropout=dropout))
344
+ block_in = block_out
345
+ if curr_res in attn_resolutions:
346
+ attn.append(make_attn(block_in, attn_type=attn_type))
347
+ down = nn.Module()
348
+ down.block = block
349
+ down.attn = attn
350
+ if i_level != self.num_resolutions-1:
351
+ down.downsample = Downsample(block_in, resamp_with_conv)
352
+ curr_res = curr_res // 2
353
+ self.down.append(down)
354
+
355
+ # middle
356
+ self.mid = nn.Module()
357
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
358
+ out_channels=block_in,
359
+ temb_channels=self.temb_ch,
360
+ dropout=dropout)
361
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
362
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
363
+ out_channels=block_in,
364
+ temb_channels=self.temb_ch,
365
+ dropout=dropout)
366
+
367
+ # upsampling
368
+ self.up = nn.ModuleList()
369
+ for i_level in reversed(range(self.num_resolutions)):
370
+ block = nn.ModuleList()
371
+ attn = nn.ModuleList()
372
+ block_out = ch*ch_mult[i_level]
373
+ skip_in = ch*ch_mult[i_level]
374
+ for i_block in range(self.num_res_blocks+1):
375
+ if i_block == self.num_res_blocks:
376
+ skip_in = ch*in_ch_mult[i_level]
377
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
378
+ out_channels=block_out,
379
+ temb_channels=self.temb_ch,
380
+ dropout=dropout))
381
+ block_in = block_out
382
+ if curr_res in attn_resolutions:
383
+ attn.append(make_attn(block_in, attn_type=attn_type))
384
+ up = nn.Module()
385
+ up.block = block
386
+ up.attn = attn
387
+ if i_level != 0:
388
+ up.upsample = Upsample(block_in, resamp_with_conv)
389
+ curr_res = curr_res * 2
390
+ self.up.insert(0, up) # prepend to get consistent order
391
+
392
+ # end
393
+ self.norm_out = Normalize(block_in)
394
+ self.conv_out = torch.nn.Conv2d(block_in,
395
+ out_ch,
396
+ kernel_size=3,
397
+ stride=1,
398
+ padding=1)
399
+
400
+ def forward(self, x, t=None, context=None):
401
+ #assert x.shape[2] == x.shape[3] == self.resolution
402
+ if context is not None:
403
+ # assume aligned context, cat along channel axis
404
+ x = torch.cat((x, context), dim=1)
405
+ if self.use_timestep:
406
+ # timestep embedding
407
+ assert t is not None
408
+ temb = get_timestep_embedding(t, self.ch)
409
+ temb = self.temb.dense[0](temb)
410
+ temb = nonlinearity(temb)
411
+ temb = self.temb.dense[1](temb)
412
413
+ else:
414
+ temb = None
415
+
416
+ # downsampling
417
+ hs = [self.conv_in(x)]
418
+ for i_level in range(self.num_resolutions):
419
+ for i_block in range(self.num_res_blocks):
420
+ h = self.down[i_level].block[i_block](hs[-1], temb)
421
+ if len(self.down[i_level].attn) > 0:
422
+ h = self.down[i_level].attn[i_block](h)
423
+ hs.append(h)
424
+ if i_level != self.num_resolutions-1:
425
+ hs.append(self.down[i_level].downsample(hs[-1]))
426
+
427
+ # middle
428
+ h = hs[-1]
429
+ h = self.mid.block_1(h, temb)
430
+ h = self.mid.attn_1(h)
431
+ h = self.mid.block_2(h, temb)
432
+
433
+ # upsampling
434
+ for i_level in reversed(range(self.num_resolutions)):
435
+ for i_block in range(self.num_res_blocks+1):
436
+ h = self.up[i_level].block[i_block](
437
+ torch.cat([h, hs.pop()], dim=1), temb)
438
+ if len(self.up[i_level].attn) > 0:
439
+ h = self.up[i_level].attn[i_block](h)
440
+ if i_level != 0:
441
+ h = self.up[i_level].upsample(h)
442
+
443
+ # end
444
+ h = self.norm_out(h)
445
+ h = nonlinearity(h)
446
+ h = self.conv_out(h)
447
+ return h
448
+
449
+ def get_last_layer(self):
450
+ return self.conv_out.weight
451
+
452
+
453
+ class Encoder(nn.Module):
454
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
455
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
456
+ resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
457
+ **ignore_kwargs):
458
+ super().__init__()
459
+ if use_linear_attn: attn_type = "linear"
460
+ self.ch = ch
461
+ self.temb_ch = 0
462
+ self.num_resolutions = len(ch_mult)
463
+ self.num_res_blocks = num_res_blocks
464
+ self.resolution = resolution
465
+ self.in_channels = in_channels
466
+
467
+ # downsampling
468
+ self.conv_in = torch.nn.Conv2d(in_channels,
469
+ self.ch,
470
+ kernel_size=3,
471
+ stride=1,
472
+ padding=1)
473
+
474
+ curr_res = resolution
475
+ in_ch_mult = (1,)+tuple(ch_mult)
476
+ self.in_ch_mult = in_ch_mult
477
+ self.down = nn.ModuleList()
478
+ for i_level in range(self.num_resolutions):
479
+ block = nn.ModuleList()
480
+ attn = nn.ModuleList()
481
+ block_in = ch*in_ch_mult[i_level]
482
+ block_out = ch*ch_mult[i_level]
483
+ for i_block in range(self.num_res_blocks):
484
+ block.append(ResnetBlock(in_channels=block_in,
485
+ out_channels=block_out,
486
+ temb_channels=self.temb_ch,
487
+ dropout=dropout))
488
+ block_in = block_out
489
+ if curr_res in attn_resolutions:
490
+ attn.append(make_attn(block_in, attn_type=attn_type))
491
+ down = nn.Module()
492
+ down.block = block
493
+ down.attn = attn
494
+ if i_level != self.num_resolutions-1:
495
+ down.downsample = Downsample(block_in, resamp_with_conv)
496
+ curr_res = curr_res // 2
497
+ self.down.append(down)
498
+
499
+ # middle
500
+ self.mid = nn.Module()
501
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
502
+ out_channels=block_in,
503
+ temb_channels=self.temb_ch,
504
+ dropout=dropout)
505
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
506
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
507
+ out_channels=block_in,
508
+ temb_channels=self.temb_ch,
509
+ dropout=dropout)
510
+
511
+ # end
512
+ self.norm_out = Normalize(block_in)
513
+ self.conv_out = torch.nn.Conv2d(block_in,
514
+ 2*z_channels if double_z else z_channels,
515
+ kernel_size=3,
516
+ stride=1,
517
+ padding=1)
518
+
519
+ def forward(self, x):
520
+ # timestep embedding
521
+ temb = None
522
+
523
+ # downsampling
524
+ hs = [self.conv_in(x)]
525
+ for i_level in range(self.num_resolutions):
526
+ for i_block in range(self.num_res_blocks):
527
+ h = self.down[i_level].block[i_block](hs[-1], temb)
528
+ if len(self.down[i_level].attn) > 0:
529
+ h = self.down[i_level].attn[i_block](h)
530
+ hs.append(h)
531
+ if i_level != self.num_resolutions-1:
532
+ hs.append(self.down[i_level].downsample(hs[-1]))
533
+
534
+ # middle
535
+ h = hs[-1]
536
+ h = self.mid.block_1(h, temb)
537
+ h = self.mid.attn_1(h)
538
+ h = self.mid.block_2(h, temb)
539
+
540
+ # end
541
+ h = self.norm_out(h)
542
+ h = nonlinearity(h)
543
+ h = self.conv_out(h)
544
+ return h
545
+
546
+
547
+ class Decoder(nn.Module):
548
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
549
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
550
+ resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
551
+ attn_type="vanilla", **ignorekwargs):
552
+ super().__init__()
553
+ if use_linear_attn: attn_type = "linear"
554
+ self.ch = ch
555
+ self.temb_ch = 0
556
+ self.num_resolutions = len(ch_mult)
557
+ self.num_res_blocks = num_res_blocks
558
+ self.resolution = resolution
559
+ self.in_channels = in_channels
560
+ self.give_pre_end = give_pre_end
561
+ self.tanh_out = tanh_out
562
+
563
+ # compute in_ch_mult, block_in and curr_res at lowest res
564
+ in_ch_mult = (1,)+tuple(ch_mult)
565
+ block_in = ch*ch_mult[self.num_resolutions-1]
566
+ curr_res = resolution // 2**(self.num_resolutions-1)
567
+ self.z_shape = (1,z_channels,curr_res,curr_res)
568
+ print("Working with z of shape {} = {} dimensions.".format(
569
+ self.z_shape, np.prod(self.z_shape)))
570
+
571
+ # z to block_in
572
+ self.conv_in = torch.nn.Conv2d(z_channels,
573
+ block_in,
574
+ kernel_size=3,
575
+ stride=1,
576
+ padding=1)
577
+
578
+ # middle
579
+ self.mid = nn.Module()
580
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
581
+ out_channels=block_in,
582
+ temb_channels=self.temb_ch,
583
+ dropout=dropout)
584
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
585
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
586
+ out_channels=block_in,
587
+ temb_channels=self.temb_ch,
588
+ dropout=dropout)
589
+
590
+ # upsampling
591
+ self.up = nn.ModuleList()
592
+ for i_level in reversed(range(self.num_resolutions)):
593
+ block = nn.ModuleList()
594
+ attn = nn.ModuleList()
595
+ block_out = ch*ch_mult[i_level]
596
+ for i_block in range(self.num_res_blocks+1):
597
+ block.append(ResnetBlock(in_channels=block_in,
598
+ out_channels=block_out,
599
+ temb_channels=self.temb_ch,
600
+ dropout=dropout))
601
+ block_in = block_out
602
+ if curr_res in attn_resolutions:
603
+ attn.append(make_attn(block_in, attn_type=attn_type))
604
+ up = nn.Module()
605
+ up.block = block
606
+ up.attn = attn
607
+ if i_level != 0:
608
+ up.upsample = Upsample(block_in, resamp_with_conv)
609
+ curr_res = curr_res * 2
610
+ self.up.insert(0, up) # prepend to get consistent order
611
+
612
+ # end
613
+ self.norm_out = Normalize(block_in)
614
+ self.conv_out = torch.nn.Conv2d(block_in,
615
+ out_ch,
616
+ kernel_size=3,
617
+ stride=1,
618
+ padding=1)
619
+
620
+ def forward(self, z):
621
+ #assert z.shape[1:] == self.z_shape[1:]
622
+ self.last_z_shape = z.shape
623
+
624
+ # timestep embedding
625
+ temb = None
626
+
627
+ # z to block_in
628
+ h = self.conv_in(z)
629
+
630
+ # middle
631
+ h = self.mid.block_1(h, temb)
632
+ h = self.mid.attn_1(h)
633
+ h = self.mid.block_2(h, temb)
634
+
635
+ # upsampling
636
+ for i_level in reversed(range(self.num_resolutions)):
637
+ for i_block in range(self.num_res_blocks+1):
638
+ h = self.up[i_level].block[i_block](h, temb)
639
+ if len(self.up[i_level].attn) > 0:
640
+ h = self.up[i_level].attn[i_block](h)
641
+ if i_level != 0:
642
+ h = self.up[i_level].upsample(h)
643
+
644
+ # end
645
+ if self.give_pre_end:
646
+ return h
647
+
648
+ h = self.norm_out(h)
649
+ h = nonlinearity(h)
650
+ h = self.conv_out(h)
651
+ if self.tanh_out:
652
+ h = torch.tanh(h)
653
+ return h
654
+
655
+
656
+ class SimpleDecoder(nn.Module):
657
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
658
+ super().__init__()
659
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
660
+ ResnetBlock(in_channels=in_channels,
661
+ out_channels=2 * in_channels,
662
+ temb_channels=0, dropout=0.0),
663
+                                      ResnetBlock(in_channels=2 * in_channels,
+                                                  out_channels=4 * in_channels,
+                                                  temb_channels=0, dropout=0.0),
+                                      ResnetBlock(in_channels=4 * in_channels,
+                                                  out_channels=2 * in_channels,
+                                                  temb_channels=0, dropout=0.0),
+                                      nn.Conv2d(2*in_channels, in_channels, 1),
+                                      Upsample(in_channels, with_conv=True)])
+         # end
+         self.norm_out = Normalize(in_channels)
+         self.conv_out = torch.nn.Conv2d(in_channels,
+                                         out_channels,
+                                         kernel_size=3,
+                                         stride=1,
+                                         padding=1)
+
+     def forward(self, x):
+         for i, layer in enumerate(self.model):
+             if i in [1, 2, 3]:
+                 x = layer(x, None)
+             else:
+                 x = layer(x)
+
+         h = self.norm_out(x)
+         h = nonlinearity(h)
+         x = self.conv_out(h)
+         return x
+
+
+ class UpsampleDecoder(nn.Module):
+     def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
+                  ch_mult=(2,2), dropout=0.0):
+         super().__init__()
+         # upsampling
+         self.temb_ch = 0
+         self.num_resolutions = len(ch_mult)
+         self.num_res_blocks = num_res_blocks
+         block_in = in_channels
+         curr_res = resolution // 2 ** (self.num_resolutions - 1)
+         self.res_blocks = nn.ModuleList()
+         self.upsample_blocks = nn.ModuleList()
+         for i_level in range(self.num_resolutions):
+             res_block = []
+             block_out = ch * ch_mult[i_level]
+             for i_block in range(self.num_res_blocks + 1):
+                 res_block.append(ResnetBlock(in_channels=block_in,
+                                              out_channels=block_out,
+                                              temb_channels=self.temb_ch,
+                                              dropout=dropout))
+                 block_in = block_out
+             self.res_blocks.append(nn.ModuleList(res_block))
+             if i_level != self.num_resolutions - 1:
+                 self.upsample_blocks.append(Upsample(block_in, True))
+                 curr_res = curr_res * 2
+
+         # end
+         self.norm_out = Normalize(block_in)
+         self.conv_out = torch.nn.Conv2d(block_in,
+                                         out_channels,
+                                         kernel_size=3,
+                                         stride=1,
+                                         padding=1)
+
+     def forward(self, x):
+         # upsampling
+         h = x
+         for k, i_level in enumerate(range(self.num_resolutions)):
+             for i_block in range(self.num_res_blocks + 1):
+                 h = self.res_blocks[i_level][i_block](h, None)
+             if i_level != self.num_resolutions - 1:
+                 h = self.upsample_blocks[k](h)
+         h = self.norm_out(h)
+         h = nonlinearity(h)
+         h = self.conv_out(h)
+         return h
+
+
+ class LatentRescaler(nn.Module):
+     def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
+         super().__init__()
+         # residual block, interpolate, residual block
+         self.factor = factor
+         self.conv_in = nn.Conv2d(in_channels,
+                                  mid_channels,
+                                  kernel_size=3,
+                                  stride=1,
+                                  padding=1)
+         self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
+                                                      out_channels=mid_channels,
+                                                      temb_channels=0,
+                                                      dropout=0.0) for _ in range(depth)])
+         self.attn = AttnBlock(mid_channels)
+         self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
+                                                      out_channels=mid_channels,
+                                                      temb_channels=0,
+                                                      dropout=0.0) for _ in range(depth)])
+
+         self.conv_out = nn.Conv2d(mid_channels,
+                                   out_channels,
+                                   kernel_size=1,
+                                   )
+
+     def forward(self, x):
+         x = self.conv_in(x)
+         for block in self.res_block1:
+             x = block(x, None)
+         x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
+         x = self.attn(x)
+         for block in self.res_block2:
+             x = block(x, None)
+         x = self.conv_out(x)
+         return x
+
+
+ class MergedRescaleEncoder(nn.Module):
+     def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
+                  attn_resolutions, dropout=0.0, resamp_with_conv=True,
+                  ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
+         super().__init__()
+         intermediate_chn = ch * ch_mult[-1]
+         self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
+                                z_channels=intermediate_chn, double_z=False, resolution=resolution,
+                                attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
+                                out_ch=None)
+         self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
+                                        mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
+
+     def forward(self, x):
+         x = self.encoder(x)
+         x = self.rescaler(x)
+         return x
+
+
+ class MergedRescaleDecoder(nn.Module):
+     def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
+                  dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
+         super().__init__()
+         tmp_chn = z_channels*ch_mult[-1]
+         self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
+                                resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
+                                ch_mult=ch_mult, resolution=resolution, ch=ch)
+         self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
+                                        out_channels=tmp_chn, depth=rescale_module_depth)
+
+     def forward(self, x):
+         x = self.rescaler(x)
+         x = self.decoder(x)
+         return x
+
+
+ class Upsampler(nn.Module):
+     def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
+         super().__init__()
+         assert out_size >= in_size
+         num_blocks = int(np.log2(out_size//in_size))+1
+         factor_up = 1.+ (out_size % in_size)
+         print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
+         self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
+                                        out_channels=in_channels)
+         self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
+                                attn_resolutions=[], in_channels=None, ch=in_channels,
+                                ch_mult=[ch_mult for _ in range(num_blocks)])
+
+     def forward(self, x):
+         x = self.rescaler(x)
+         x = self.decoder(x)
+         return x
+
+
+ class Resize(nn.Module):
+     def __init__(self, in_channels=None, learned=False, mode="bilinear"):
+         super().__init__()
+         self.with_conv = learned
+         self.mode = mode
+         if self.with_conv:
+             print(f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode")
+             raise NotImplementedError()
+             assert in_channels is not None
+             # no asymmetric padding in torch conv, must do it ourselves
+             self.conv = torch.nn.Conv2d(in_channels,
+                                         in_channels,
+                                         kernel_size=4,
+                                         stride=2,
+                                         padding=1)
+
+     def forward(self, x, scale_factor=1.0):
+         if scale_factor == 1.0:
+             return x
+         else:
+             x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
+         return x
crs_core/modules/diffusionmodules/openaimodel.py ADDED
@@ -0,0 +1,794 @@
+ from abc import abstractmethod
+ import math
+
+ import numpy as np
+ import torch as th
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from crs_core.modules.diffusionmodules.util import (
+     checkpoint,
+     conv_nd,
+     linear,
+     avg_pool_nd,
+     zero_module,
+     normalization,
+     timestep_embedding,
+     timestep_embedding_t,
+ )
+ from crs_core.modules.attention import SpatialTransformer
+ from crs_core.utils import exists
+
+
+ # dummy replace
+ def convert_module_to_f16(x):
+     pass
+
+ def convert_module_to_f32(x):
+     pass
+
+
+ ## go
+ class AttentionPool2d(nn.Module):
+     """
+     Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
+     """
+
+     def __init__(
+         self,
+         spacial_dim: int,
+         embed_dim: int,
+         num_heads_channels: int,
+         output_dim: int = None,
+     ):
+         super().__init__()
+         self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
+         self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
+         self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
+         self.num_heads = embed_dim // num_heads_channels
+         self.attention = QKVAttention(self.num_heads)
+
+     def forward(self, x):
+         b, c, *_spatial = x.shape
+         x = x.reshape(b, c, -1)  # NC(HW)
+         x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)
+         x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)
+         x = self.qkv_proj(x)
+         x = self.attention(x)
+         x = self.c_proj(x)
+         return x[:, :, 0]
+
+
+ class TimestepBlock(nn.Module):
+     """
+     Any module where forward() takes timestep embeddings as a second argument.
+     """
+
+     @abstractmethod
+     def forward(self, x, emb):
+         """
+         Apply the module to `x` given `emb` timestep embeddings.
+         """
+
+
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+     """
+     A sequential module that passes timestep embeddings to the children that
+     support it as an extra input.
+     """
+
+     def forward(self, x, emb, context=None):
+         for layer in self:
+             if isinstance(layer, TimestepBlock):
+                 x = layer(x, emb)
+             elif isinstance(layer, SpatialTransformer):
+                 x = layer(x, context)
+             else:
+                 x = layer(x)
+         return x
+
+
+ class Upsample(nn.Module):
+     """
+     An upsampling layer with an optional convolution.
+     :param channels: channels in the inputs and outputs.
+     :param use_conv: a bool determining if a convolution is applied.
+     :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+         upsampling occurs in the inner-two dimensions.
+     """
+
+     def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+         super().__init__()
+         self.channels = channels
+         self.out_channels = out_channels or channels
+         self.use_conv = use_conv
+         self.dims = dims
+         if use_conv:
+             self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
+
+     def forward(self, x):
+         assert x.shape[1] == self.channels
+         if self.dims == 3:
+             x = F.interpolate(
+                 x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
+             )
+         else:
+             x = F.interpolate(x, scale_factor=2, mode="nearest")
+         if self.use_conv:
+             x = self.conv(x)
+         return x
+
+
+ class TransposedUpsample(nn.Module):
+     'Learned 2x upsampling without padding'
+     def __init__(self, channels, out_channels=None, ks=5):
+         super().__init__()
+         self.channels = channels
+         self.out_channels = out_channels or channels
+
+         self.up = nn.ConvTranspose2d(self.channels, self.out_channels, kernel_size=ks, stride=2)
+
+     def forward(self, x):
+         return self.up(x)
+
+
+ class Downsample(nn.Module):
+     """
+     A downsampling layer with an optional convolution.
+     :param channels: channels in the inputs and outputs.
+     :param use_conv: a bool determining if a convolution is applied.
+     :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+         downsampling occurs in the inner-two dimensions.
+     """
+
+     def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+         super().__init__()
+         self.channels = channels
+         self.out_channels = out_channels or channels
+         self.use_conv = use_conv
+         self.dims = dims
+         stride = 2 if dims != 3 else (1, 2, 2)
+         if use_conv:
+             self.op = conv_nd(
+                 dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
+             )
+         else:
+             assert self.channels == self.out_channels
+             self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+     def forward(self, x):
+         assert x.shape[1] == self.channels
+         return self.op(x)
+
+
+ class ResBlock(TimestepBlock):
+     """
+     A residual block that can optionally change the number of channels.
+     :param channels: the number of input channels.
+     :param emb_channels: the number of timestep embedding channels.
+     :param dropout: the rate of dropout.
+     :param out_channels: if specified, the number of out channels.
+     :param use_conv: if True and out_channels is specified, use a spatial
+         convolution instead of a smaller 1x1 convolution to change the
+         channels in the skip connection.
+     :param dims: determines if the signal is 1D, 2D, or 3D.
+     :param use_checkpoint: if True, use gradient checkpointing on this module.
+     :param up: if True, use this block for upsampling.
+     :param down: if True, use this block for downsampling.
+     """
+
+     def __init__(
+         self,
+         channels,
+         emb_channels,
+         dropout,
+         out_channels=None,
+         use_conv=False,
+         use_scale_shift_norm=False,
+         dims=2,
+         use_checkpoint=False,
+         up=False,
+         down=False,
+     ):
+         super().__init__()
+         self.channels = channels
+         self.emb_channels = emb_channels
+         self.dropout = dropout
+         self.out_channels = out_channels or channels
+         self.use_conv = use_conv
+         self.use_checkpoint = use_checkpoint
+         self.use_scale_shift_norm = use_scale_shift_norm
+
+         self.in_layers = nn.Sequential(
+             normalization(channels),
+             nn.SiLU(),
+             conv_nd(dims, channels, self.out_channels, 3, padding=1),
+         )
+
+         self.updown = up or down
+
+         if up:
+             self.h_upd = Upsample(channels, False, dims)
+             self.x_upd = Upsample(channels, False, dims)
+         elif down:
+             self.h_upd = Downsample(channels, False, dims)
+             self.x_upd = Downsample(channels, False, dims)
+         else:
+             self.h_upd = self.x_upd = nn.Identity()
+
+         self.emb_layers = nn.Sequential(
+             nn.SiLU(),
+             linear(
+                 emb_channels,
+                 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+             ),
+         )
+         self.out_layers = nn.Sequential(
+             normalization(self.out_channels),
+             nn.SiLU(),
+             nn.Dropout(p=dropout),
+             zero_module(
+                 conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
+             ),
+         )
+
+         if self.out_channels == channels:
+             self.skip_connection = nn.Identity()
+         elif use_conv:
+             self.skip_connection = conv_nd(
+                 dims, channels, self.out_channels, 3, padding=1
+             )
+         else:
+             self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+     def forward(self, x, emb):
+         """
+         Apply the block to a Tensor, conditioned on a timestep embedding.
+         :param x: an [N x C x ...] Tensor of features.
+         :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+         :return: an [N x C x ...] Tensor of outputs.
+         """
+         return checkpoint(
+             self._forward, (x, emb), self.parameters(), self.use_checkpoint
+         )
+
+     def _forward(self, x, emb):
+         if self.updown:
+             in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+             h = in_rest(x)
+             h = self.h_upd(h)
+             x = self.x_upd(x)
+             h = in_conv(h)
+         else:
+             h = self.in_layers(x)
+         emb_out = self.emb_layers(emb).type(h.dtype)
+         while len(emb_out.shape) < len(h.shape):
+             emb_out = emb_out[..., None]
+         if self.use_scale_shift_norm:
+             out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+             scale, shift = th.chunk(emb_out, 2, dim=1)
+             h = out_norm(h) * (1 + scale) + shift
+             h = out_rest(h)
+         else:
+             h = h + emb_out
+             h = self.out_layers(h)
+         return self.skip_connection(x) + h
+
+
+ class AttentionBlock(nn.Module):
+     """
+     An attention block that allows spatial positions to attend to each other.
+     Originally ported from here, but adapted to the N-d case.
+     https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+     """
+
+     def __init__(
+         self,
+         channels,
+         num_heads=1,
+         num_head_channels=-1,
+         use_checkpoint=False,
+         use_new_attention_order=False,
+     ):
+         super().__init__()
+         self.channels = channels
+         if num_head_channels == -1:
+             self.num_heads = num_heads
+         else:
+             assert (
+                 channels % num_head_channels == 0
+             ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+             self.num_heads = channels // num_head_channels
+         self.use_checkpoint = use_checkpoint
+         self.norm = normalization(channels)
+         self.qkv = conv_nd(1, channels, channels * 3, 1)
+         if use_new_attention_order:
+             # split qkv before split heads
+             self.attention = QKVAttention(self.num_heads)
+         else:
+             # split heads before split qkv
+             self.attention = QKVAttentionLegacy(self.num_heads)
+
+         self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+
+     def forward(self, x):
+         return checkpoint(self._forward, (x,), self.parameters(), True)  # TODO: check checkpoint usage, is True  # TODO: fix the .half call!!!
+         # return pt_checkpoint(self._forward, x)  # pytorch
+
+     def _forward(self, x):
+         b, c, *spatial = x.shape
+         x = x.reshape(b, c, -1)
+         qkv = self.qkv(self.norm(x))
+         h = self.attention(qkv)
+         h = self.proj_out(h)
+         return (x + h).reshape(b, c, *spatial)
+
+
+ def count_flops_attn(model, _x, y):
+     """
+     A counter for the `thop` package to count the operations in an
+     attention operation.
+     Meant to be used like:
+         macs, params = thop.profile(
+             model,
+             inputs=(inputs, timestamps),
+             custom_ops={QKVAttention: QKVAttention.count_flops},
+         )
+     """
+     b, c, *spatial = y[0].shape
+     num_spatial = int(np.prod(spatial))
+     # We perform two matmuls with the same number of ops.
+     # The first computes the weight matrix, the second computes
+     # the combination of the value vectors.
+     matmul_ops = 2 * b * (num_spatial ** 2) * c
+     model.total_ops += th.DoubleTensor([matmul_ops])
+
+
+ class QKVAttentionLegacy(nn.Module):
+     """
349
+ A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
350
+ """
351
+
352
+ def __init__(self, n_heads):
353
+ super().__init__()
354
+ self.n_heads = n_heads
355
+
356
+ def forward(self, qkv):
357
+ """
358
+ Apply QKV attention.
359
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
360
+ :return: an [N x (H * C) x T] tensor after attention.
361
+ """
362
+ bs, width, length = qkv.shape
363
+ assert width % (3 * self.n_heads) == 0
364
+ ch = width // (3 * self.n_heads)
365
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
366
+ scale = 1 / math.sqrt(math.sqrt(ch))
367
+ weight = th.einsum(
368
+ "bct,bcs->bts", q * scale, k * scale
369
+ ) # More stable with f16 than dividing afterwards
370
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
371
+ a = th.einsum("bts,bcs->bct", weight, v)
372
+ return a.reshape(bs, -1, length)
373
+
374
+ @staticmethod
375
+ def count_flops(model, _x, y):
376
+ return count_flops_attn(model, _x, y)
377
+
378
+
379
+ class QKVAttention(nn.Module):
380
+ """
381
+ A module which performs QKV attention and splits in a different order.
382
+ """
383
+
384
+ def __init__(self, n_heads):
385
+ super().__init__()
386
+ self.n_heads = n_heads
387
+
388
+ def forward(self, qkv):
389
+ """
390
+ Apply QKV attention.
391
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
392
+ :return: an [N x (H * C) x T] tensor after attention.
393
+ """
394
+ bs, width, length = qkv.shape
395
+ assert width % (3 * self.n_heads) == 0
396
+ ch = width // (3 * self.n_heads)
397
+ q, k, v = qkv.chunk(3, dim=1)
398
+ scale = 1 / math.sqrt(math.sqrt(ch))
399
+ weight = th.einsum(
400
+ "bct,bcs->bts",
401
+ (q * scale).view(bs * self.n_heads, ch, length),
402
+ (k * scale).view(bs * self.n_heads, ch, length),
403
+ ) # More stable with f16 than dividing afterwards
404
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
405
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
406
+ return a.reshape(bs, -1, length)
407
+
408
+ @staticmethod
409
+ def count_flops(model, _x, y):
410
+ return count_flops_attn(model, _x, y)
411
+
412
+
413
+ class UNetModel(nn.Module):
414
+ """
415
+ The full UNet model with attention and timestep embedding.
416
+ :param in_channels: channels in the input Tensor.
417
+ :param model_channels: base channel count for the model.
418
+ :param out_channels: channels in the output Tensor.
419
+ :param num_res_blocks: number of residual blocks per downsample.
420
+ :param attention_resolutions: a collection of downsample rates at which
421
+ attention will take place. May be a set, list, or tuple.
422
+ For example, if this contains 4, then at 4x downsampling, attention
423
+ will be used.
424
+ :param dropout: the dropout probability.
425
+ :param channel_mult: channel multiplier for each level of the UNet.
426
+ :param conv_resample: if True, use learned convolutions for upsampling and
427
+ downsampling.
428
+ :param dims: determines if the signal is 1D, 2D, or 3D.
429
+ :param num_classes: if specified (as an int), then this model will be
430
+ class-conditional with `num_classes` classes.
431
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
432
+ :param num_heads: the number of attention heads in each attention layer.
433
+ :param num_heads_channels: if specified, ignore num_heads and instead use
434
+ a fixed channel width per attention head.
435
+ :param num_heads_upsample: works with num_heads to set a different number
436
+ of heads for upsampling. Deprecated.
437
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
438
+ :param resblock_updown: use residual blocks for up/downsampling.
439
+ :param use_new_attention_order: use a different attention pattern for potentially
440
+ increased efficiency.
441
+ """
442
+
443
+ def __init__(
444
+ self,
445
+ image_size,
446
+ in_channels,
447
+ model_channels,
448
+ out_channels,
449
+ num_res_blocks,
450
+ attention_resolutions,
451
+ metadata=None,
452
+ dropout=0,
453
+ channel_mult=(1, 2, 4, 8),
454
+ conv_resample=True,
455
+ dims=2,
456
+ num_classes=None,
457
+ use_checkpoint=False,
458
+ use_fp16=False,
459
+ num_heads=-1,
460
+ num_head_channels=-1,
461
+ num_heads_upsample=-1,
462
+ use_scale_shift_norm=False,
463
+ resblock_updown=False,
464
+ use_new_attention_order=False,
465
+ use_spatial_transformer=False, # custom transformer support
466
+ transformer_depth=1, # custom transformer support
467
+ context_dim=None, # custom transformer support
468
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
469
+ legacy=True,
470
+ disable_self_attentions=None,
471
+ num_attention_blocks=None,
472
+ disable_middle_self_attn=False,
473
+ use_linear_in_transformer=False,
474
+ ):
475
+ super().__init__()
476
+ if use_spatial_transformer:
477
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
478
+
479
+ if context_dim is not None:
480
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
481
+ from omegaconf.listconfig import ListConfig
482
+ if type(context_dim) == ListConfig:
483
+ context_dim = list(context_dim)
484
+
485
+ if num_heads_upsample == -1:
486
+ num_heads_upsample = num_heads
487
+
488
+ if num_heads == -1:
489
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
490
+
491
+ if num_head_channels == -1:
492
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
493
+
494
+ self.image_size = image_size
495
+ self.in_channels = in_channels
496
+ self.model_channels = model_channels
497
+ self.out_channels = out_channels
498
+ if isinstance(num_res_blocks, int):
499
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
500
+ else:
501
+ if len(num_res_blocks) != len(channel_mult):
502
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
503
+ "as a list/tuple (per-level) with the same length as channel_mult")
504
+ self.num_res_blocks = num_res_blocks
505
+ if disable_self_attentions is not None:
506
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
507
+ assert len(disable_self_attentions) == len(channel_mult)
508
+ if num_attention_blocks is not None:
509
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
510
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
511
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
512
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
513
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
514
+ f"attention will still not be set.")
515
+
516
+ self.attention_resolutions = attention_resolutions
517
+ self.dropout = dropout
518
+ self.channel_mult = channel_mult
519
+ self.conv_resample = conv_resample
520
+ self.num_classes = num_classes
521
+ self.use_checkpoint = use_checkpoint
522
+ self.dtype = th.float16 if use_fp16 else th.float32
523
+ self.num_heads = num_heads
524
+ self.num_head_channels = num_head_channels
525
+ self.num_heads_upsample = num_heads_upsample
526
+ self.predict_codebook_ids = n_embed is not None
527
+ # self.metadata_emb=instantiate_from_config(metadata_config)
528
+
529
+ time_embed_dim = model_channels * 4
530
+ self.time_embed = nn.Sequential(
531
+ linear(model_channels, time_embed_dim),
532
+ nn.SiLU(),
533
+ linear(time_embed_dim, time_embed_dim),
534
+ )
535
+
536
+ if self.num_classes is not None:
537
+ if isinstance(self.num_classes, int):
538
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
539
+ elif self.num_classes == "continuous":
540
+ print("setting up linear c_adm embedding layer")
541
+ self.label_emb = nn.Linear(1, time_embed_dim)
542
+ else:
543
+ raise ValueError()
544
+
545
+ self.input_blocks = nn.ModuleList(
546
+ [
547
+ TimestepEmbedSequential(
548
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
549
+ )
550
+ ]
551
+ )
552
+ self._feature_size = model_channels
553
+ input_block_chans = [model_channels]
554
+ ch = model_channels
555
+ ds = 1
556
+ for level, mult in enumerate(channel_mult):
557
+ for nr in range(self.num_res_blocks[level]):
558
+ layers = [
559
+ ResBlock(
560
+ ch,
561
+ time_embed_dim,
562
+ dropout,
563
+ out_channels=mult * model_channels,
564
+ dims=dims,
565
+ use_checkpoint=use_checkpoint,
566
+ use_scale_shift_norm=use_scale_shift_norm,
567
+ )
568
+ ]
569
+ ch = mult * model_channels
570
+ if ds in attention_resolutions:
571
+ if num_head_channels == -1:
572
+ dim_head = ch // num_heads
573
+ else:
574
+ num_heads = ch // num_head_channels
575
+ dim_head = num_head_channels
576
+ if legacy:
577
+ #num_heads = 1
578
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
579
+ if exists(disable_self_attentions):
580
+ disabled_sa = disable_self_attentions[level]
581
+ else:
582
+ disabled_sa = False
583
+
584
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
585
+ layers.append(
586
+ AttentionBlock(
587
+ ch,
588
+ use_checkpoint=use_checkpoint,
589
+ num_heads=num_heads,
590
+ num_head_channels=dim_head,
591
+ use_new_attention_order=use_new_attention_order,
592
+ ) if not use_spatial_transformer else SpatialTransformer(
593
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
594
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
595
+ use_checkpoint=use_checkpoint
596
+ )
597
+ )
598
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
599
+ self._feature_size += ch
600
+ input_block_chans.append(ch)
601
+ if level != len(channel_mult) - 1:
602
+ out_ch = ch
603
+ self.input_blocks.append(
604
+ TimestepEmbedSequential(
605
+ ResBlock(
606
+ ch,
607
+ time_embed_dim,
608
+ dropout,
609
+ out_channels=out_ch,
610
+ dims=dims,
611
+ use_checkpoint=use_checkpoint,
612
+ use_scale_shift_norm=use_scale_shift_norm,
613
+ down=True,
614
+ )
615
+ if resblock_updown
616
+ else Downsample(
617
+ ch, conv_resample, dims=dims, out_channels=out_ch
618
+ )
619
+ )
620
+ )
621
+ ch = out_ch
622
+ input_block_chans.append(ch)
623
+ ds *= 2
624
+ self._feature_size += ch
625
+
626
+ if num_head_channels == -1:
627
+ dim_head = ch // num_heads
628
+ else:
629
+ num_heads = ch // num_head_channels
630
+ dim_head = num_head_channels
631
+ if legacy:
632
+ #num_heads = 1
633
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
634
+ self.middle_block = TimestepEmbedSequential(
635
+ ResBlock(
636
+ ch,
637
+ time_embed_dim,
638
+ dropout,
639
+ dims=dims,
640
+ use_checkpoint=use_checkpoint,
641
+ use_scale_shift_norm=use_scale_shift_norm,
642
+ ),
643
+ AttentionBlock(
644
+ ch,
645
+ use_checkpoint=use_checkpoint,
646
+ num_heads=num_heads,
647
+ num_head_channels=dim_head,
648
+ use_new_attention_order=use_new_attention_order,
649
+ ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
650
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
651
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
652
+ use_checkpoint=use_checkpoint
653
+ ),
654
+ ResBlock(
655
+ ch,
656
+ time_embed_dim,
657
+ dropout,
658
+ dims=dims,
659
+ use_checkpoint=use_checkpoint,
660
+ use_scale_shift_norm=use_scale_shift_norm,
661
+ ),
662
+ )
663
+ self._feature_size += ch
664
+
665
+ self.output_blocks = nn.ModuleList([])
666
+ for level, mult in list(enumerate(channel_mult))[::-1]:
667
+ for i in range(self.num_res_blocks[level] + 1):
668
+ ich = input_block_chans.pop()
669
+ layers = [
670
+ ResBlock(
671
+ ch + ich,
672
+ time_embed_dim,
673
+ dropout,
674
+ out_channels=model_channels * mult,
675
+ dims=dims,
676
+ use_checkpoint=use_checkpoint,
677
+ use_scale_shift_norm=use_scale_shift_norm,
678
+ )
679
+ ]
680
+ ch = model_channels * mult
681
+ if ds in attention_resolutions:
682
+ if num_head_channels == -1:
683
+ dim_head = ch // num_heads
684
+ else:
685
+ num_heads = ch // num_head_channels
686
+ dim_head = num_head_channels
687
+ if legacy:
688
+ #num_heads = 1
689
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
690
+ if exists(disable_self_attentions):
691
+ disabled_sa = disable_self_attentions[level]
692
+ else:
693
+ disabled_sa = False
694
+
695
+ if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
696
+ layers.append(
697
+ AttentionBlock(
698
+ ch,
699
+ use_checkpoint=use_checkpoint,
700
+ num_heads=num_heads_upsample,
701
+ num_head_channels=dim_head,
702
+ use_new_attention_order=use_new_attention_order,
703
+ ) if not use_spatial_transformer else SpatialTransformer(
704
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
705
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
706
+ use_checkpoint=use_checkpoint
707
+ )
708
+ )
709
+ if level and i == self.num_res_blocks[level]:
710
+ out_ch = ch
711
+ layers.append(
712
+ ResBlock(
713
+ ch,
714
+ time_embed_dim,
715
+ dropout,
716
+ out_channels=out_ch,
717
+ dims=dims,
718
+ use_checkpoint=use_checkpoint,
719
+ use_scale_shift_norm=use_scale_shift_norm,
720
+ up=True,
721
+ )
722
+ if resblock_updown
723
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
724
+ )
725
+ ds //= 2
726
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
727
+ self._feature_size += ch
728
+
729
+ self.out = nn.Sequential(
730
+ normalization(ch),
731
+ nn.SiLU(),
732
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
733
+ )
734
+ if self.predict_codebook_ids:
735
+ self.id_predictor = nn.Sequential(
736
+ normalization(ch),
737
+ conv_nd(dims, model_channels, n_embed, 1),
738
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
739
+ )
740
+
741
+ def convert_to_fp16(self):
742
+ """
743
+ Convert the torso of the model to float16.
744
+ """
745
+ self.input_blocks.apply(convert_module_to_f16)
746
+ self.middle_block.apply(convert_module_to_f16)
747
+ self.output_blocks.apply(convert_module_to_f16)
748
+
749
+ def convert_to_fp32(self):
750
+ """
751
+ Convert the torso of the model to float32.
752
+ """
753
+ self.input_blocks.apply(convert_module_to_f32)
754
+ self.middle_block.apply(convert_module_to_f32)
755
+ self.output_blocks.apply(convert_module_to_f32)
756
+
757
+ def forward(self, x, timesteps=None, metadata=None,context=None, y=None,**kwargs):
758
+ """
759
+ Apply the model to an input batch.
760
+ :param x: an [N x C x ...] Tensor of inputs.
761
+ :param timesteps: a 1-D batch of timesteps.
762
+ :param context: conditioning plugged in via crossattn
763
+ :param y: an [N] Tensor of labels, if class-conditional.
764
+ :return: an [N x C x ...] Tensor of outputs.
765
+ """
766
+ if len(metadata)==1:
767
+ metadata=metadata[0]
768
+ assert (y is not None) == (
769
+ self.num_classes is not None
770
+ ), "must specify y if and only if the model is class-conditional"
771
+ hs = []
772
+ t_emb = timestep_embedding(timesteps, self.model_channels,repeat_only=False)
773
+
774
+ emb = self.time_embed(t_emb)
775
+ emb+=metadata
776
+
777
+
778
+ if self.num_classes is not None:
779
+ assert y.shape[0] == x.shape[0]
780
+ emb = emb + self.label_emb(y)
781
+
782
+ h = x.type(self.dtype)
783
+ for module in self.input_blocks:
784
+ h = module(h, emb, context)
785
+ hs.append(h)
786
+ h = self.middle_block(h, emb, context)
787
+ for module in self.output_blocks:
788
+ h = th.cat([h, hs.pop()], dim=1)
789
+ h = module(h, emb, context)
790
+ h = h.type(x.dtype)
791
+ if self.predict_codebook_ids:
792
+ return self.id_predictor(h)
793
+ else:
794
+ return self.out(h)
crs_core/modules/diffusionmodules/util.py ADDED
@@ -0,0 +1,140 @@
+# adapted from OpenAI improved-diffusion and guided-diffusion (nn.py)
+
+import math
+
+import torch
+import torch.nn as nn
+from einops import repeat
+
+
+def checkpoint(func, inputs, params, flag):
+    """
+    Evaluate a function without caching intermediate activations, allowing for
+    reduced memory at the expense of extra compute in the backward pass.
+    """
+    if flag:
+        args = tuple(inputs) + tuple(params)
+        return CheckpointFunction.apply(func, len(inputs), *args)
+    else:
+        return func(*inputs)
+
+
+class CheckpointFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, run_function, length, *args):
+        ctx.run_function = run_function
+        ctx.input_tensors = list(args[:length])
+        ctx.input_params = list(args[length:])
+        ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
+                                   "dtype": torch.get_autocast_gpu_dtype(),
+                                   "cache_enabled": torch.is_autocast_cache_enabled()}
+        with torch.no_grad():
+            output_tensors = ctx.run_function(*ctx.input_tensors)
+        return output_tensors
+
+    @staticmethod
+    def backward(ctx, *output_grads):
+        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+        with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
+            # Shallow copies guard against run_function modifying the detached
+            # input tensors in place, which autograd does not allow.
+            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+            output_tensors = ctx.run_function(*shallow_copies)
+        input_grads = torch.autograd.grad(
+            output_tensors,
+            ctx.input_tensors + ctx.input_params,
+            output_grads,
+            allow_unused=True,
+        )
+        del ctx.input_tensors
+        del ctx.input_params
+        del output_tensors
+        return (None, None) + input_grads
+
+
+class SinusoidalEmbedding(nn.Module):
+    def __init__(self, max_value, embedding_dim):
+        super().__init__()
+        self.max_value = max_value
+        self.embedding_dim = embedding_dim
+        self.omega = 10000
+
+    def forward(self, k):
+        k_normalized = k * self.max_value
+        embedding = torch.zeros((k.size(0), k.size(1), self.embedding_dim), device=k.device)
+        for j in range(k.size(1)):
+            for i in range(self.embedding_dim // 2):
+                embedding[:, j, 2 * i] = torch.sin(k_normalized[:, j] * (self.omega ** (-2 * i / self.embedding_dim)))
+                embedding[:, j, 2 * i + 1] = torch.cos(k_normalized[:, j] * (self.omega ** (-2 * i / self.embedding_dim)))
+        return embedding.view(k.size(0), -1)
+
+
+def create_condition_vector(metadata, mlp_models, embedding_dim):
+    metadata_embeddings = [mlp_models[j](metadata[:, j * embedding_dim:(j + 1) * embedding_dim]) for j in range(len(mlp_models))]
+    return sum(metadata_embeddings)
+
+
+def timestep_embedding_t(timesteps, dim, max_period=10000, repeat_only=False):
+    if not repeat_only:
+        half = dim // 2
+        freqs = torch.exp(
+            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+        ).to(device=timesteps.device)
+        args = timesteps[:, None].float() * freqs[None]
+        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+        if dim % 2:
+            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+    else:
+        embedding = repeat(timesteps, 'b -> b d', d=dim)
+    return embedding
+
+
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+    """Create sinusoidal timestep embeddings of shape [N x dim]."""
+    if repeat_only:
+        return repeat(timesteps, 'b -> b d', d=dim)
+    half = dim // 2
+    freqs = torch.exp(
+        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+    ).to(device=timesteps.device)
+    args = timesteps[:, None].float() * freqs[None]
+    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+    if dim % 2:
+        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+    return embedding
+
+
+def zero_module(module):
+    """Zero out the parameters of a module and return it."""
+    for p in module.parameters():
+        p.detach().zero_()
+    return module
+
+
+def normalization(channels):
+    """Make a standard normalization layer (GroupNorm with 32 groups)."""
+    return GroupNorm32(32, channels)
+
+
+class GroupNorm32(nn.GroupNorm):
+    def forward(self, x):
+        return super().forward(x.float()).type(x.dtype)
+
+
+def conv_nd(dims, *args, **kwargs):
+    """Create a 1D, 2D, or 3D convolution module."""
+    if dims == 1:
+        return nn.Conv1d(*args, **kwargs)
+    elif dims == 2:
+        return nn.Conv2d(*args, **kwargs)
+    elif dims == 3:
+        return nn.Conv3d(*args, **kwargs)
+    raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+    """Create a linear module."""
+    return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+    """Create a 1D, 2D, or 3D average pooling module."""
+    if dims == 1:
+        return nn.AvgPool1d(*args, **kwargs)
+    elif dims == 2:
+        return nn.AvgPool2d(*args, **kwargs)
+    elif dims == 3:
+        return nn.AvgPool3d(*args, **kwargs)
+    raise ValueError(f"unsupported dimensions: {dims}")
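The sinusoidal formula in `timestep_embedding` above can be checked with a pure-Python sketch (no torch; the helper name `sinusoidal_embedding` is hypothetical): cosine terms come first, then sine terms, with frequencies `exp(-log(max_period) * i / half)`.

```python
import math


def sinusoidal_embedding(t, dim, max_period=10000):
    # Mirrors timestep_embedding for a single scalar timestep t:
    # cos terms first, then sin terms, zero-padded if dim is odd.
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    emb = [math.cos(t * f) for f in freqs] + [math.sin(t * f) for f in freqs]
    if dim % 2:
        emb.append(0.0)
    return emb


e = sinusoidal_embedding(0.0, 8)
print(e[:4])  # cos(0) everywhere: [1.0, 1.0, 1.0, 1.0]
print(e[4:])  # sin(0) everywhere: [0.0, 0.0, 0.0, 0.0]
```

At `t = 0` every cosine channel is exactly 1 and every sine channel is exactly 0, which makes a quick sanity check independent of the frequency schedule.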
crs_core/modules/distributions/__init__.py ADDED
File without changes
crs_core/modules/distributions/distributions.py ADDED
@@ -0,0 +1,92 @@
+import torch
+import numpy as np
+
+
+class AbstractDistribution:
+    def sample(self):
+        raise NotImplementedError()
+
+    def mode(self):
+        raise NotImplementedError()
+
+
+class DiracDistribution(AbstractDistribution):
+    def __init__(self, value):
+        self.value = value
+
+    def sample(self):
+        return self.value
+
+    def mode(self):
+        return self.value
+
+
+class DiagonalGaussianDistribution(object):
+    def __init__(self, parameters, deterministic=False):
+        self.parameters = parameters
+        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+        self.deterministic = deterministic
+        self.std = torch.exp(0.5 * self.logvar)
+        self.var = torch.exp(self.logvar)
+        if self.deterministic:
+            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+
+    def sample(self):
+        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
+        return x
+
+    def kl(self, other=None):
+        if self.deterministic:
+            return torch.Tensor([0.])
+        else:
+            if other is None:
+                return 0.5 * torch.sum(torch.pow(self.mean, 2)
+                                       + self.var - 1.0 - self.logvar,
+                                       dim=[1, 2, 3])
+            else:
+                return 0.5 * torch.sum(
+                    torch.pow(self.mean - other.mean, 2) / other.var
+                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
+                    dim=[1, 2, 3])
+
+    def nll(self, sample, dims=[1, 2, 3]):
+        if self.deterministic:
+            return torch.Tensor([0.])
+        logtwopi = np.log(2.0 * np.pi)
+        return 0.5 * torch.sum(
+            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+            dim=dims)
+
+    def mode(self):
+        return self.mean
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+    """
+    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
+    Compute the KL divergence between two gaussians.
+    Shapes are automatically broadcasted, so batches can be compared to
+    scalars, among other use cases.
+    """
+    tensor = None
+    for obj in (mean1, logvar1, mean2, logvar2):
+        if isinstance(obj, torch.Tensor):
+            tensor = obj
+            break
+    assert tensor is not None, "at least one argument must be a Tensor"
+
+    # Force variances to be Tensors. Broadcasting helps convert scalars to
+    # Tensors, but it does not work for torch.exp().
+    logvar1, logvar2 = [
+        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
+        for x in (logvar1, logvar2)
+    ]
+
+    return 0.5 * (
+        -1.0
+        + logvar2
+        - logvar1
+        + torch.exp(logvar1 - logvar2)
+        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
+    )
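The closed form in `normal_kl` can be sanity-checked for scalars without torch (the scalar re-implementation `normal_kl_scalar` is a hypothetical helper, not part of the repo):

```python
import math


def normal_kl_scalar(mean1, logvar1, mean2, logvar2):
    # Same closed form as normal_kl above, specialized to scalar inputs.
    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + math.exp(logvar1 - logvar2)
        + (mean1 - mean2) ** 2 * math.exp(-logvar2)
    )


# KL between identical Gaussians is zero.
print(normal_kl_scalar(0.0, 0.0, 0.0, 0.0))  # 0.0
# KL(N(1, 1) || N(0, 1)) = 0.5 * (mean difference)^2 / variance = 0.5
print(normal_kl_scalar(1.0, 0.0, 0.0, 0.0))  # 0.5
```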
crs_core/text_encoder.py ADDED
@@ -0,0 +1,35 @@
+import torch
+import torch.nn as nn
+from transformers import CLIPTextModel, CLIPTokenizer
+
+
+class FrozenCLIPEmbedder(nn.Module):
+    def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
+                 freeze=True, layer="last", layer_idx=None):
+        super().__init__()
+        self.tokenizer = CLIPTokenizer.from_pretrained(version)
+        self.transformer = CLIPTextModel.from_pretrained(version)
+        self.device = device
+        self.max_length = max_length
+        self.layer = layer
+        self.layer_idx = layer_idx
+        if freeze:
+            self.transformer = self.transformer.eval()
+            for p in self.parameters():
+                p.requires_grad = False
+
+    def forward(self, text):
+        enc = self.tokenizer(
+            text, truncation=True, max_length=self.max_length,
+            return_length=True, return_overflowing_tokens=False,
+            padding="max_length", return_tensors="pt"
+        )
+        tokens = enc["input_ids"].to(next(self.transformer.parameters()).device)
+        out = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden")
+        if self.layer == "last":
+            return out.last_hidden_state
+        if self.layer == "pooled":
+            return out.pooler_output[:, None, :]
+        return out.hidden_states[self.layer_idx]
+
+    def encode(self, text):
+        return self(text)
crs_core/utils.py ADDED
@@ -0,0 +1,19 @@
+import importlib
+
+
+def exists(val):
+    return val is not None
+
+
+def get_obj_from_str(string, reload=False):
+    module, cls = string.rsplit('.', 1)
+    if reload:
+        module_imp = importlib.import_module(module)
+        importlib.reload(module_imp)
+    return getattr(importlib.import_module(module, package=None), cls)
+
+
+def instantiate_from_config(config):
+    if "target" not in config:
+        raise KeyError("Expected key `target` in config")
+    return get_obj_from_str(config["target"])(**config.get("params", dict()))
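The `target`-string mechanism in `crs_core/utils.py` can be exercised with a stdlib class; a minimal self-contained sketch (duplicating the two helpers so it runs on its own):

```python
import importlib


def get_obj_from_str(string, reload=False):
    # Split "pkg.module.ClassName" into a module path and an attribute name.
    module, cls = string.rsplit('.', 1)
    if reload:
        importlib.reload(importlib.import_module(module))
    return getattr(importlib.import_module(module), cls)


def instantiate_from_config(config):
    if "target" not in config:
        raise KeyError("Expected key `target` in config")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


# Resolve a stdlib class by dotted path, then instantiate it from a config dict.
cls = get_obj_from_str("collections.OrderedDict")
obj = instantiate_from_config({"target": "collections.OrderedDict"})
print(cls.__name__)          # OrderedDict
print(isinstance(obj, cls))  # True
```

This is the same pattern Stable Diffusion-style configs use: `target` names the class, `params` become constructor kwargs.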
global_content_adapter/config.json ADDED
@@ -0,0 +1,8 @@
+{
+  "in_dim": 768,
+  "channel_mult": [
+    2,
+    4
+  ],
+  "_target": "crs_core.global_adapter.GlobalContentAdapter"
+}
global_content_adapter/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1fb4e62e0079a9a1b1707435048ac213fa50beabd830c71f29639c71da4259c8
+size 188855312
global_text_adapter/config.json ADDED
@@ -0,0 +1,4 @@
+{
+  "in_dim": 768,
+  "_target": "crs_core.global_adapter.GlobalTextAdapter"
+}
local_adapter/config.json ADDED
@@ -0,0 +1,36 @@
+{
+  "in_channels": 4,
+  "model_channels": 320,
+  "local_channels": 18,
+  "inject_channels": [
+    192,
+    256,
+    384,
+    512
+  ],
+  "inject_layers": [
+    1,
+    4,
+    7,
+    10
+  ],
+  "num_res_blocks": 2,
+  "attention_resolutions": [
+    4,
+    2,
+    1
+  ],
+  "channel_mult": [
+    1,
+    2,
+    4,
+    4
+  ],
+  "use_checkpoint": true,
+  "num_heads": 8,
+  "use_spatial_transformer": true,
+  "transformer_depth": 1,
+  "context_dim": 768,
+  "legacy": false,
+  "_target": "crs_core.local_adapter.LocalAdapter"
+}
local_adapter/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8e68605b04077b6cea4861af89a346c14fef5732c92b6d246038cfd24c85a283
+size 1677896968
metadata_encoder/config.json ADDED
@@ -0,0 +1,7 @@
+{
+  "max_value": 1000,
+  "embedding_dim": 320,
+  "metadata_dim": 7,
+  "max_period": 10000,
+  "_target": "crs_core.metadata_embedding.metadata_embeddings"
+}
metadata_encoder/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:de051e7bb2b0a61e1c5ad28178b3116b76dc2b686ac9aaa073cd54c7366fe5a4
+size 11505912
model_index.json ADDED
@@ -0,0 +1,15 @@
+{
+  "_class_name": "CRSDiffPipeline",
+  "_diffusers_version": "0.32.2",
+  "crs_model": [
+    "pipeline",
+    "CRSDiffPipeline"
+  ],
+  "scheduler": [
+    "diffusers",
+    "DDIMScheduler"
+  ],
+  "scale_factor": 0.18215,
+  "conditioning_key": "crossattn",
+  "channels": 4
+}
modular_pipeline.py ADDED
@@ -0,0 +1,197 @@
+"""CRS-Diff modular loading utilities for custom diffusers pipeline."""
+
+import importlib
+import json
+import sys
+from pathlib import Path
+from typing import Dict, Optional, Union
+
+import torch
+from diffusers import DDIMScheduler
+
+_PIPELINE_DIR = Path(__file__).resolve().parent
+if str(_PIPELINE_DIR) not in sys.path:
+    sys.path.insert(0, str(_PIPELINE_DIR))
+
+_COMPONENT_NAMES = (
+    "unet",
+    "vae",
+    "text_encoder",
+    "local_adapter",
+    "global_content_adapter",
+    "global_text_adapter",
+    "metadata_encoder",
+)
+
+_TARGET_MAP = {
+    "crs_core.local_adapter.LocalControlUNetModel": "crs_core.local_adapter.LocalControlUNetModel",
+    "crs_core.autoencoder.AutoencoderKL": "crs_core.autoencoder.AutoencoderKL",
+    "crs_core.text_encoder.FrozenCLIPEmbedder": "crs_core.text_encoder.FrozenCLIPEmbedder",
+    "crs_core.local_adapter.LocalAdapter": "crs_core.local_adapter.LocalAdapter",
+    "crs_core.global_adapter.GlobalContentAdapter": "crs_core.global_adapter.GlobalContentAdapter",
+    "crs_core.global_adapter.GlobalTextAdapter": "crs_core.global_adapter.GlobalTextAdapter",
+    "crs_core.metadata_embedding.metadata_embeddings": "crs_core.metadata_embedding.metadata_embeddings",
+}
+
+
+def ensure_model_path(pretrained_model_name_or_path: Union[str, Path]) -> Path:
+    """Resolve local path or download HF repo snapshot."""
+    path = Path(pretrained_model_name_or_path)
+    if not path.exists():
+        from huggingface_hub import snapshot_download
+
+        path = Path(snapshot_download(str(pretrained_model_name_or_path)))
+    path = path.resolve()
+    if str(path) not in sys.path:
+        sys.path.insert(0, str(path))
+    return path
+
+
+def resolve_model_root(candidate: Optional[Union[str, Path]]) -> Optional[Path]:
+    """Resolve to folder containing model_index.json."""
+    if not candidate:
+        return None
+    path = ensure_model_path(candidate)
+    if (path / "model_index.json").exists():
+        return path
+    cur = path
+    for _ in range(5):
+        parent = cur.parent
+        if parent == cur:
+            break
+        if (parent / "model_index.json").exists():
+            return parent
+        cur = parent
+    return None
+
+
+def _get_class(target: str):
+    module_path, cls_name = target.rsplit(".", 1)
+    mod = importlib.import_module(module_path)
+    return getattr(mod, cls_name)
+
+
+def load_component(model_root: Path, name: str):
+    """Load single split component from <repo>/<name>/."""
+    root = Path(model_root)
+    comp_path = root / name
+    with (comp_path / "config.json").open("r", encoding="utf-8") as f:
+        cfg = json.load(f)
+    target = cfg.pop("_target", None)
+    if not target:
+        raise ValueError(f"No _target in {comp_path / 'config.json'}")
+    target = _TARGET_MAP.get(target, target)
+    cls_ref = _get_class(target)
+    params = {k: v for k, v in cfg.items() if not k.startswith("_")}
+    module = cls_ref(**params)
+
+    weight_file = comp_path / "diffusion_pytorch_model.safetensors"
+    if weight_file.exists():
+        from safetensors.torch import load_file
+
+        state = load_file(str(weight_file))
+        module.load_state_dict(state, strict=True)
+    module.eval()
+    return module
+
+
+class CRSModelWrapper(torch.nn.Module):
+    """Wrap split components to mimic CRSControlNet APIs used by pipeline."""
+
+    def __init__(
+        self,
+        unet,
+        vae,
+        text_encoder,
+        local_adapter,
+        global_content_adapter,
+        global_text_adapter,
+        metadata_encoder,
+        channels: int = 4,
+    ):
+        super().__init__()
+        self.model = torch.nn.Module()
+        self.model.add_module("diffusion_model", unet)
+        self.first_stage_model = vae
+        self.cond_stage_model = text_encoder
+        self.local_adapter = local_adapter
+        self.global_content_adapter = global_content_adapter
+        self.global_text_adapter = global_text_adapter
+        self.metadata_emb = metadata_encoder
+        self.local_control_scales = [1.0] * 13
+        self.channels = channels
+
+    @torch.no_grad()
+    def get_learned_conditioning(self, prompts):
+        if hasattr(self.cond_stage_model, "device"):
+            self.cond_stage_model.device = str(next(self.parameters()).device)
+        return self.cond_stage_model.encode(prompts)
+
+    def apply_model(self, x_noisy, t, cond, metadata=None, global_strength=1.0, **kwargs):
+        del kwargs
+        if metadata is None:
+            metadata = cond["metadata"]
+        cond_txt = torch.cat(cond["c_crossattn"], 1)
+
+        if cond.get("global_control") is not None and cond["global_control"][0] is not None:
+            metadata = self.metadata_emb(metadata)
+            content_t, _ = cond["global_control"][0].chunk(2, dim=1)
+            global_control = self.global_content_adapter(content_t)
+            cond_txt = self.global_text_adapter(cond_txt)
+            cond_txt = torch.cat([cond_txt, global_strength * global_control], dim=1)
+
+        local_control = None
+        if cond.get("local_control") is not None and cond["local_control"][0] is not None:
+            local_control = torch.cat(cond["local_control"], 1)
+            local_control = self.local_adapter(
+                x=x_noisy, timesteps=t, context=cond_txt, local_conditions=local_control
+            )
+            local_control = [c * s for c, s in zip(local_control, self.local_control_scales)]
+
+        return self.model.diffusion_model(
+            x=x_noisy,
+            timesteps=t,
+            metadata=metadata,
+            context=cond_txt,
+            local_control=local_control,
+            meta=True,
+        )
+
+    def decode_first_stage(self, z):
+        return self.first_stage_model.decode(z)
+
+
+def load_components(model_root: Union[str, Path]) -> Dict[str, object]:
+    """Load pipeline components from split directories."""
+    root = ensure_model_path(model_root)
+    scheduler = DDIMScheduler.from_pretrained(root, subfolder="scheduler")
+
+    scale_factor = 0.18215
+    channels = 4
+    if (root / "model_index.json").exists():
+        with (root / "model_index.json").open("r", encoding="utf-8") as f:
+            idx = json.load(f)
+        scale_factor = float(idx.get("scale_factor", scale_factor))
+        channels = int(idx.get("channels", channels))
+
+    has_split_components = all((root / name / "config.json").exists() for name in _COMPONENT_NAMES)
+    if not has_split_components:
+        missing = [name for name in _COMPONENT_NAMES if not (root / name / "config.json").exists()]
+        raise FileNotFoundError(
+            f"CRS-Diff split component export incomplete. Missing: {missing}. "
+            "Expected split folders with config.json and weights."
+        )
+
+    loaded = {name: load_component(root, name) for name in _COMPONENT_NAMES}
+    crs_model = CRSModelWrapper(
+        unet=loaded["unet"],
+        vae=loaded["vae"],
+        text_encoder=loaded["text_encoder"],
+        local_adapter=loaded["local_adapter"],
+        global_content_adapter=loaded["global_content_adapter"],
+        global_text_adapter=loaded["global_text_adapter"],
+        metadata_encoder=loaded["metadata_encoder"],
+        channels=channels,
+    )
+
+    return {"crs_model": crs_model, "scheduler": scheduler, "scale_factor": scale_factor}
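The `_target` convention in `load_component` — pop `_target`, drop underscore-prefixed keys, pass the rest as constructor kwargs — can be sketched standalone (the helper name `split_component_config` is hypothetical):

```python
def split_component_config(cfg):
    # Mirror load_component: `_target` names the class to build, keys starting
    # with "_" are metadata, everything else becomes constructor kwargs.
    cfg = dict(cfg)  # avoid mutating the caller's config
    target = cfg.pop("_target", None)
    params = {k: v for k, v in cfg.items() if not k.startswith("_")}
    return target, params


target, params = split_component_config(
    {"in_dim": 768, "_target": "crs_core.global_adapter.GlobalTextAdapter"}
)
print(target)  # crs_core.global_adapter.GlobalTextAdapter
print(params)  # {'in_dim': 768}
```

This matches the shape of the component `config.json` files in this repo, e.g. `global_text_adapter/config.json`.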
pipeline.py ADDED
@@ -0,0 +1,199 @@
+import sys
+from dataclasses import dataclass
+from pathlib import Path
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+from diffusers import DDIMScheduler, DiffusionPipeline
+from diffusers.utils import BaseOutput
+from PIL import Image
+
+_ROOT = Path(__file__).resolve().parent
+if str(_ROOT) not in sys.path:
+    sys.path.insert(0, str(_ROOT))
+
+# Register alias for cached custom-pipeline imports.
+sys.modules["pipeline"] = sys.modules[__name__]
+
+from modular_pipeline import load_components, resolve_model_root  # noqa: E402
+
+
+@dataclass
+class CRSDiffPipelineOutput(BaseOutput):
+    images: List[Image.Image]
+
+
+class CRSDiffPipeline(DiffusionPipeline):
+    def register_modules(self, **kwargs):
+        for name, module in kwargs.items():
+            if module is None or (
+                isinstance(module, (tuple, list)) and len(module) > 0 and module[0] is None
+            ):
+                self.register_to_config(**{name: (None, None)})
+                setattr(self, name, module)
+            elif _is_component_list(module):
+                self.register_to_config(**{name: (module[0], module[1])})
+                setattr(self, name, module)
+            else:
+                from diffusers.pipelines.pipeline_loading_utils import _fetch_class_library_tuple
+
+                library, class_name = _fetch_class_library_tuple(module)
+                self.register_to_config(**{name: (library, class_name)})
+                setattr(self, name, module)
+
+    def __init__(
+        self,
+        crs_model=None,
+        scheduler=None,
+        scale_factor: float = 0.18215,
+        model_path: Optional[Union[str, Path]] = None,
+        _name_or_path: Optional[Union[str, Path]] = None,
+    ):
+        super().__init__()
+        if _is_component_list(crs_model) or _is_component_list(scheduler):
+            model_root = (
+                resolve_model_root(model_path)
+                or resolve_model_root(_name_or_path)
+                or resolve_model_root(getattr(getattr(self, "config", None), "_name_or_path", None))
+            )
+            if model_root is None:
+                raise ValueError(
+                    "CRS-Diff received config placeholders but could not resolve model path. "
+                    "Pass `model_path` or load via DiffusionPipeline.from_pretrained(<path>, custom_pipeline=...)."
+                )
+            loaded = load_components(model_root)
+            crs_model = loaded["crs_model"]
+            scheduler = loaded["scheduler"]
+            scale_factor = loaded["scale_factor"]
+
+        self.register_modules(crs_model=crs_model, scheduler=scheduler)
+        self.vae_scale_factor = scale_factor
+
+    @property
+    def device(self) -> torch.device:
+        params = list(self.crs_model.parameters())
+        if params:
+            return params[0].device
+        return torch.device("cpu")
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path: Union[str, Path],
+        device: Optional[Union[str, torch.device]] = None,
+        subfolder: Optional[str] = None,
+        **kwargs,
+    ) -> "CRSDiffPipeline":
+        path = resolve_model_root(pretrained_model_name_or_path)
+        if path is None:
+            raise ValueError(f"Could not resolve CRS-Diff model root from: {pretrained_model_name_or_path}")
+
+        subfolder = kwargs.pop("subfolder", subfolder)
+        if subfolder == "scheduler":
+            return DDIMScheduler.from_pretrained(path, subfolder="scheduler")
+
+        loaded = load_components(path)
+        pipe = cls(crs_model=loaded["crs_model"], scheduler=loaded["scheduler"], scale_factor=loaded["scale_factor"])
+        if device is not None:
+            pipe = pipe.to(device)
+        return pipe
+
+    def _to_tensor(self, x, device: torch.device, dtype=torch.float32) -> torch.Tensor:
+        if isinstance(x, np.ndarray):
+            x = torch.from_numpy(x)
+        if not isinstance(x, torch.Tensor):
+            raise TypeError("Expected torch.Tensor or np.ndarray for conditioning inputs.")
+        return x.to(device=device, dtype=dtype)
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        prompt: Union[str, List[str]],
+        local_control,
+        global_control,
+        metadata,
+        negative_prompt: Union[str, List[str]] = "",
+        num_inference_steps: int = 50,
+        guidance_scale: float = 7.5,
+        eta: float = 0.0,
+        strength: float = 1.0,
+        global_strength: float = 1.0,
+        generator: Optional[torch.Generator] = None,
+        output_type: str = "pil",
+    ) -> CRSDiffPipelineOutput:
+        device = self.device
+        local_control = self._to_tensor(local_control, device=device)
+        global_control = self._to_tensor(global_control, device=device)
+        metadata = self._to_tensor(metadata, device=device)
+
+        batch_size = local_control.shape[0]
+        if isinstance(prompt, str):
+            prompt = [prompt] * batch_size
+        if isinstance(negative_prompt, str):
+            negative_prompt = [negative_prompt] * batch_size
+
+        if metadata.dim() == 1:
+            metadata = metadata.unsqueeze(0).repeat(batch_size, 1)
+
+        cond = {
+            "local_control": [local_control],
+            "c_crossattn": [self.crs_model.get_learned_conditioning(prompt)],
+            "global_control": [global_control],
+        }
+        un_cond = {
+            "local_control": [local_control],
+            "c_crossattn": [self.crs_model.get_learned_conditioning(negative_prompt)],
+            "global_control": [torch.zeros_like(global_control)],
+        }
+
+        if hasattr(self.crs_model, "local_control_scales"):
+            self.crs_model.local_control_scales = [strength] * 13
+
+        _, _, h, w = local_control.shape
+        latents = torch.randn(
+            (batch_size, self.crs_model.channels, h // 8, w // 8),
+            generator=generator,
+            device=device,
+        )
+        latents = latents * self.scheduler.init_noise_sigma
+
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        for t in self.scheduler.timesteps:
+            ts = torch.full((batch_size,), int(t), device=device, dtype=torch.long)
+            if guidance_scale > 1.0:
+                noise_text = self.crs_model.apply_model(latents, ts, cond, metadata, global_strength)
+                noise_uncond = self.crs_model.apply_model(latents, ts, un_cond, metadata, global_strength)
+                noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
+            else:
+                noise_pred = self.crs_model.apply_model(latents, ts, cond, metadata, global_strength)
+
+            latents = self.scheduler.step(
+                model_output=noise_pred,
+                timestep=t,
+                sample=latents,
+                eta=eta,
+                generator=generator,
+                return_dict=True,
+            ).prev_sample
+
+        images = self.crs_model.decode_first_stage(latents)
+        images = images.clamp(-1, 1)
+        images = ((images + 1.0) / 2.0).permute(0, 2, 3, 1).cpu().numpy()
+        images = (images * 255.0).clip(0, 255).astype(np.uint8)
+
+        if output_type == "pil":
+            images = [Image.fromarray(img) for img in images]
+        elif output_type != "numpy":
+            raise ValueError("output_type must be 'pil' or 'numpy'")
+
+        return CRSDiffPipelineOutput(images=images)
+
+
+def _is_component_list(v):
+    return (
+        isinstance(v, (list, tuple))
+        and len(v) == 2
+        and isinstance(v[0], str)
+        and isinstance(v[1], str)
+    )
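The classifier-free guidance step inside `CRSDiffPipeline.__call__` combines the text-conditioned and unconditional noise predictions; a minimal numeric sketch (with hypothetical scalar stand-ins for the noise tensors):

```python
def guided_noise(noise_uncond, noise_text, guidance_scale):
    # Same combination as in the denoising loop above: extrapolate from the
    # unconditional prediction toward the text-conditioned one.
    return noise_uncond + guidance_scale * (noise_text - noise_uncond)


print(guided_noise(0.0, 1.0, 7.5))  # 7.5
print(guided_noise(2.0, 2.0, 7.5))  # 2.0 (no change when the predictions agree)
```

At `guidance_scale == 1.0` the result equals the conditional prediction, which is why the pipeline skips the second `apply_model` call entirely when `guidance_scale <= 1.0`.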
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,19 @@
+{
+  "_class_name": "DDIMScheduler",
+  "_diffusers_version": "0.36.0",
+  "beta_end": 0.012,
+  "beta_schedule": "scaled_linear",
+  "beta_start": 0.00085,
+  "clip_sample": false,
+  "clip_sample_range": 1.0,
+  "dynamic_thresholding_ratio": 0.995,
+  "num_train_timesteps": 1000,
+  "prediction_type": "epsilon",
+  "rescale_betas_zero_snr": false,
+  "sample_max_value": 1.0,
+  "set_alpha_to_one": false,
+  "steps_offset": 0,
+  "thresholding": false,
+  "timestep_spacing": "leading",
+  "trained_betas": null
+}
text_encoder/config.json ADDED
@@ -0,0 +1,3 @@
+{
+  "_target": "crs_core.text_encoder.FrozenCLIPEmbedder"
+}
text_encoder/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:651247bce4134453769880497b0ff59124fe047ee7cd7c91ed55308e6503195d
+size 492267488
unet/config.json ADDED
@@ -0,0 +1,25 @@
+{
+  "image_size": 32,
+  "in_channels": 4,
+  "model_channels": 320,
+  "out_channels": 4,
+  "num_res_blocks": 2,
+  "attention_resolutions": [
+    4,
+    2,
+    1
+  ],
+  "channel_mult": [
+    1,
+    2,
+    4,
+    4
+  ],
+  "use_checkpoint": true,
+  "num_heads": 8,
+  "use_spatial_transformer": true,
+  "transformer_depth": 1,
+  "context_dim": 768,
+  "legacy": false,
+  "_target": "crs_core.local_adapter.LocalControlUNetModel"
+}
unet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:56afd34c05a153aec419a4a5516da3b1dce24e62bf4bc9ced88a7298fe7d6973
+size 3438164120
vae/config.json ADDED
@@ -0,0 +1,25 @@
+{
+  "embed_dim": 4,
+  "monitor": "val/rec_loss",
+  "ddconfig": {
+    "double_z": true,
+    "z_channels": 4,
+    "resolution": 256,
+    "in_channels": 3,
+    "out_ch": 3,
+    "ch": 128,
+    "ch_mult": [
+      1,
+      2,
+      4,
+      4
+    ],
+    "num_res_blocks": 2,
+    "attn_resolutions": [],
+    "dropout": 0.0
+  },
+  "lossconfig": {
+    "target": "torch.nn.Identity"
+  },
+  "_target": "crs_core.autoencoder.AutoencoderKL"
+}
vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ce952e59654ae764f1f53f8f40da9eece9fcea54d6e26f12ce9bf5124ba5617e
+size 334640988