BiliSakura commited on
Commit
2ccd4c6
·
verified ·
1 Parent(s): 4ad5f15

Upload folder using huggingface_hub

Browse files
README.md CHANGED
@@ -1,3 +1,87 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # DiffusionSat Custom Pipelines
2
+
3
+ Custom community pipelines for loading DiffusionSat checkpoints directly with `diffusers.DiffusionPipeline.from_pretrained()`.
4
+
5
+ > See [Diffusers Community Pipeline Documentation](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipeline_overview)
6
+
7
+ ## Available Pipelines
8
+
9
+ This directory contains two custom pipelines:
10
+
11
+ 1. **`pipeline_diffusionsat.py`**: Standard text-to-image pipeline with DiffusionSat metadata support.
12
+ 2. **`pipeline_diffusionsat_controlnet.py`**: ControlNet pipeline with DiffusionSat metadata and conditional metadata support.
13
+
14
+ ## Setup
15
+
16
+ The checkpoint folder (`ckpt/diffusionsat/`) should contain the standard diffusers components (unet, vae, scheduler, etc.). You can reference these pipeline files directly from this directory or copy them to your checkpoint folder.
17
+
18
+ ## Usage
19
+
20
+ ### 1. Text-to-Image Pipeline
21
+
22
+ Use `pipeline_diffusionsat.py` for standard generation.
23
+
24
+ ```python
25
+ import torch
26
+ from diffusers import DiffusionPipeline
27
+
28
+ # Load pipeline
29
+ pipe = DiffusionPipeline.from_pretrained(
30
+ "path/to/ckpt/diffusionsat",
31
+ custom_pipeline="./custom_pipelines/pipeline_diffusionsat.py", # Path to this file
32
+ torch_dtype=torch.float16,
33
+ trust_remote_code=True,
34
+ )
35
+ pipe = pipe.to("cuda")
36
+
37
+ # Optional: Metadata (normalized lat, lon, timestamp, GSD, etc.)
38
+ # metadata = [0.5, -0.3, 0.7, 0.2, 0.1, 0.0, 0.5]
39
+
40
+ # Generate
41
+ image = pipe(
42
+ "satellite image of farmland",
43
+ metadata=None, # Optional
44
+ num_inference_steps=30,
45
+ ).images[0]
46
+ ```
47
+
48
+ ### 2. ControlNet Pipeline
49
+
50
+ Use `pipeline_diffusionsat_controlnet.py` for ControlNet generation.
51
+
52
+ ```python
53
+ import torch
54
+ from diffusers import DiffusionPipeline, ControlNetModel
55
+ from diffusers.utils import load_image
56
+
57
+ # 1. Load ControlNet
58
+ controlnet = ControlNetModel.from_pretrained(
59
+ "path/to/ckpt/diffusionsat/controlnet",
60
+ torch_dtype=torch.float16
61
+ )
62
+
63
+ # 2. Load Pipeline with ControlNet
64
+ pipe = DiffusionPipeline.from_pretrained(
65
+ "path/to/ckpt/diffusionsat",
66
+ controlnet=controlnet,
67
+ custom_pipeline="./custom_pipelines/pipeline_diffusionsat_controlnet.py", # Path to this file
68
+ torch_dtype=torch.float16,
69
+ trust_remote_code=True,
70
+ )
71
+ pipe = pipe.to("cuda")
72
+
73
+ # 3. Prepare Control Image
74
+ control_image = load_image("path/to/conditioning_image.png")
75
+
76
+ # 4. Generate
77
+ # metadata: Target image metadata (optional)
78
+ # cond_metadata: Conditioning image metadata (optional)
79
+
80
+ image = pipe(
81
+ "satellite image of farmland",
82
+ image=control_image,
83
+ metadata=None,
84
+ cond_metadata=None,
85
+ num_inference_steps=30,
86
+ ).images[0]
87
+ ```
config.json ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "ControlNetModel",
3
+ "_diffusers_version": "0.17.0",
4
+ "_name_or_path": "/data/jiabo/diffusionsat/testoutput/checkpoint-1",
5
+ "act_fn": "silu",
6
+ "attention_head_dim": [
7
+ 5,
8
+ 10,
9
+ 20,
10
+ 20
11
+ ],
12
+ "block_out_channels": [
13
+ 320,
14
+ 640,
15
+ 1280,
16
+ 1280
17
+ ],
18
+ "class_embed_type": null,
19
+ "conditioning_embedding_out_channels": [
20
+ 16,
21
+ 32,
22
+ 96,
23
+ 256
24
+ ],
25
+ "conditioning_in_channels": 3,
26
+ "conditioning_scale": 1,
27
+ "controlnet_conditioning_channel_order": "rgb",
28
+ "cross_attention_dim": 1024,
29
+ "down_block_types": [
30
+ "CrossAttnDownBlock2D",
31
+ "CrossAttnDownBlock2D",
32
+ "CrossAttnDownBlock2D",
33
+ "DownBlock2D"
34
+ ],
35
+ "downsample_padding": 1,
36
+ "flip_sin_to_cos": true,
37
+ "freq_shift": 0,
38
+ "global_pool_conditions": false,
39
+ "in_channels": 4,
40
+ "layers_per_block": 2,
41
+ "mid_block_scale_factor": 1,
42
+ "norm_eps": 1e-05,
43
+ "norm_num_groups": 32,
44
+ "num_class_embeds": null,
45
+ "num_metadata": 7,
46
+ "only_cross_attention": false,
47
+ "projection_class_embeddings_input_dim": null,
48
+ "resnet_time_scale_shift": "default",
49
+ "upcast_attention": true,
50
+ "use_linear_projection": true,
51
+ "use_metadata": true
52
+ }
controlnet/config.json ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": ["controlnet", "ControlNetModel"],
3
+ "_diffusers_version": "0.17.0",
4
+ "act_fn": "silu",
5
+ "attention_head_dim": [
6
+ 5,
7
+ 10,
8
+ 20,
9
+ 20
10
+ ],
11
+ "block_out_channels": [
12
+ 320,
13
+ 640,
14
+ 1280,
15
+ 1280
16
+ ],
17
+ "class_embed_type": null,
18
+ "conditioning_embedding_out_channels": [
19
+ 16,
20
+ 32,
21
+ 96,
22
+ 256
23
+ ],
24
+ "conditioning_in_channels": 3,
25
+ "conditioning_scale": 1,
26
+ "controlnet_conditioning_channel_order": "rgb",
27
+ "cross_attention_dim": 1024,
28
+ "down_block_types": [
29
+ "CrossAttnDownBlock2D",
30
+ "CrossAttnDownBlock2D",
31
+ "CrossAttnDownBlock2D",
32
+ "DownBlock2D"
33
+ ],
34
+ "downsample_padding": 1,
35
+ "flip_sin_to_cos": true,
36
+ "freq_shift": 0,
37
+ "global_pool_conditions": false,
38
+ "in_channels": 4,
39
+ "layers_per_block": 2,
40
+ "mid_block_scale_factor": 1,
41
+ "norm_eps": 1e-05,
42
+ "norm_num_groups": 32,
43
+ "num_class_embeds": null,
44
+ "num_metadata": 7,
45
+ "only_cross_attention": false,
46
+ "projection_class_embeddings_input_dim": null,
47
+ "resnet_time_scale_shift": "default",
48
+ "upcast_attention": true,
49
+ "use_linear_projection": true,
50
+ "use_metadata": true
51
+ }
controlnet/controlnet.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ControlNet wrapper that reuses diffusers implementation and adds metadata."""
2
+ from typing import Any, Dict, Optional, Tuple, Union
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ from diffusers.models.controlnets.controlnet import (
9
+ ControlNetConditioningEmbedding as HFConditioningEmbedding,
10
+ ControlNetModel as HFControlNetModel,
11
+ ControlNetOutput,
12
+ zero_module,
13
+ )
14
+ from diffusers.utils import logging
15
+
16
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
17
+
18
+
19
class ControlNetConditioningEmbedding(HFConditioningEmbedding):
    """Adapter to allow variable downsample stride via `scale` while reusing upstream layers.

    When ``scale == 1`` this is byte-for-byte the upstream embedding. Otherwise the
    paired conv blocks are rebuilt so that some stages use stride 1 instead of the
    upstream's fixed stride-2 downsampling; ``conv_in`` and ``conv_out`` from the
    base initialization are kept unchanged.
    """

    def __init__(
        self,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
        scale: int = 1,
    ):
        # Initialize base, then optionally override blocks to respect custom stride.
        super().__init__(
            conditioning_embedding_channels=conditioning_embedding_channels,
            conditioning_channels=conditioning_channels,
            block_out_channels=block_out_channels,
        )
        if scale != 1:
            blocks = nn.ModuleList([])
            current_scale = scale
            for i in range(len(block_out_channels) - 1):
                channel_in = block_out_channels[i]
                channel_out = block_out_channels[i + 1]
                # Channel-preserving 3x3 conv, mirroring the upstream block pair layout.
                blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
                # NOTE(review): stages downsample (stride 2) while the accumulated scale
                # is below 8 and stop downsampling once it reaches 8 — presumably so the
                # embedding's total downsample factor matches the VAE's /8 regardless of
                # the starting `scale`. Confirm against the DiffusionSat training code.
                stride = 2 if current_scale < 8 else 1
                blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=stride))
                if current_scale != 8:
                    current_scale = int(current_scale * 2)
            # Swap in the rebuilt stack; parameters of the discarded upstream blocks
            # are simply dropped (they were only created during super().__init__).
            self.blocks = blocks
47
+
48
+
49
class ControlNetModel(HFControlNetModel):
    """Thin wrapper around `diffusers.ControlNetModel` with metadata embeddings.

    DiffusionSat extends the stock ControlNet in two ways:

    * ``metadata_embedding`` — one ``TimestepEmbedding`` per metadata scalar;
      their outputs are summed and added to the timestep embedding in ``forward``.
    * when ``conditioning_scale != 1`` the conditioning embedding is replaced by
      the local :class:`ControlNetConditioningEmbedding` so the downsample stride
      pattern can differ from upstream.
    """

    def __init__(
        self,
        *args,
        conditioning_in_channels: int = 3,
        conditioning_scale: int = 1,
        use_metadata: bool = True,
        num_metadata: int = 7,
        **kwargs,
    ):
        # Map alias to upstream argument (`conditioning_in_channels` is the name
        # used in DiffusionSat configs, `conditioning_channels` upstream).
        kwargs.setdefault("conditioning_channels", conditioning_in_channels)

        super().__init__(*args, **kwargs)

        # Track custom config entries for save/load parity.
        self.register_to_config(
            use_metadata=use_metadata, num_metadata=num_metadata, conditioning_scale=conditioning_scale
        )

        self.use_metadata = use_metadata
        self.num_metadata = num_metadata

        if use_metadata:
            # Reuse the dimensions of the existing time embedding MLP so each
            # metadata scalar is embedded exactly like a timestep.
            timestep_input_dim = self.time_embedding.linear_1.in_features
            time_embed_dim = self.time_embedding.linear_2.out_features
            self.metadata_embedding = nn.ModuleList(
                [
                    self._build_metadata_embedding(timestep_input_dim, time_embed_dim)
                    for _ in range(num_metadata)
                ]
            )
        else:
            self.metadata_embedding = None

        # Optionally replace conditioning embedding to honor `conditioning_scale` stride tweaks.
        if conditioning_scale != 1:
            # BUG FIX: upstream `blocks` alternates (channel-preserving conv,
            # downsampling conv), so `blocks[1::2]` yields only
            # block_out_channels[1:]. The first width — which equals
            # `conv_in.out_channels` — was previously dropped, shifting every
            # channel of the rebuilt embedding. Prepend it explicitly.
            self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
                conditioning_embedding_channels=self.controlnet_cond_embedding.conv_out.out_channels,
                conditioning_channels=conditioning_in_channels,
                block_out_channels=(
                    self.controlnet_cond_embedding.conv_in.out_channels,
                    *(layer.out_channels for layer in self.controlnet_cond_embedding.blocks[1::2]),
                ),
                scale=conditioning_scale,
            )

    @staticmethod
    def _build_metadata_embedding(timestep_input_dim: int, time_embed_dim: int) -> nn.Module:
        """Build one per-scalar metadata embedding MLP (same shape as the time embedding)."""
        from diffusers.models.embeddings import TimestepEmbedding

        return TimestepEmbedding(timestep_input_dim, time_embed_dim)

    def _encode_metadata(
        self, metadata: Optional[torch.Tensor], dtype: torch.dtype
    ) -> Optional[torch.Tensor]:
        """Project `(batch, num_metadata)` scalars to a summed embedding vector.

        Returns ``None`` when the model was built without metadata support.
        Raises ``ValueError`` if metadata is required but missing or mis-shaped.
        """
        if self.metadata_embedding is None:
            return None
        if metadata is None:
            raise ValueError("metadata must be provided when use_metadata=True")
        if metadata.dim() != 2 or metadata.shape[1] != self.num_metadata:
            raise ValueError(f"Invalid metadata shape {metadata.shape}, expected (batch, {self.num_metadata})")

        md_bsz = metadata.shape[0]
        # Sinusoidal projection of every scalar at once, then split per metadata slot.
        projected = self.time_proj(metadata.view(-1)).view(md_bsz, self.num_metadata, -1).to(dtype=dtype)

        # Each slot has its own MLP; the per-slot embeddings are summed.
        md_emb = projected.new_zeros((md_bsz, projected.shape[-1]))
        for idx, md_embed in enumerate(self.metadata_embedding):
            md_emb = md_emb + md_embed(projected[:, idx, :])
        return md_emb

    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: torch.Tensor,
        conditioning_scale: float = 1.0,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guess_mode: bool = False,
        metadata: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
        """Upstream `ControlNetModel.forward` with metadata added to the timestep embedding.

        Args mirror the upstream signature; `metadata` is the only addition and is
        required (non-None) when the model was built with ``use_metadata=True``.
        """
        # Start from upstream logic, inserting metadata into the timestep embeddings.

        channel_order = self.config.controlnet_conditioning_channel_order
        if channel_order == "bgr":
            controlnet_cond = torch.flip(controlnet_cond, dims=[1])
        elif channel_order != "rgb":
            raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")

        # Convert a 0/1 mask to an additive attention bias.
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # Broadcast scalar / 0-dim timesteps to a per-sample 1-D tensor.
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # mps/npu lack float64/int64 kernels, hence the narrower dtypes.
            is_mps = sample.device.type == "mps"
            is_npu = sample.device.type == "npu"
            if isinstance(timestep, float):
                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
            else:
                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps).to(dtype=sample.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # NOTE(review): `get_class_embed` / `get_aug_embed` must exist on the
        # installed diffusers `ControlNetModel`; confirm with the pinned version.
        class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
        if class_emb is not None:
            if self.config.class_embed_type == "timestep":
                class_emb = class_emb.to(dtype=sample.dtype)
            emb = emb + class_emb

        aug_emb = self.get_aug_embed(
            emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs or {}
        )
        if aug_emb is not None:
            emb = emb + aug_emb

        # DiffusionSat addition: fold metadata into the conditioning embedding.
        md_emb = self._encode_metadata(metadata=metadata, dtype=sample.dtype)
        if md_emb is not None:
            emb = emb + md_emb

        sample = self.conv_in(sample)
        controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
        sample = sample + controlnet_cond

        # Down blocks, collecting residuals for the paired UNet.
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
            down_block_res_samples += res_samples

        if self.mid_block is not None:
            if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
                sample = self.mid_block(
                    sample,
                    emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample = self.mid_block(sample, emb)

        # Zero-initialized projection convs on every residual.
        controlnet_down_block_res_samples = ()
        for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
            down_block_res_sample = controlnet_block(down_block_res_sample)
            controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
        down_block_res_samples = controlnet_down_block_res_samples

        mid_block_res_sample = self.controlnet_mid_block(sample)

        if guess_mode and not self.config.global_pool_conditions:
            # Guess mode: exponentially increasing scales, strongest at the mid block.
            scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) * conditioning_scale
            down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
            mid_block_res_sample = mid_block_res_sample * scales[-1]
        else:
            down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
            mid_block_res_sample = mid_block_res_sample * conditioning_scale

        if self.config.global_pool_conditions:
            down_block_res_samples = [
                torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
            ]
            mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)

        if not return_dict:
            return (down_block_res_samples, mid_block_res_sample)

        return ControlNetOutput(
            down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
        )
controlnet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3bd5f6b9aea04714f331cd94d721c8adb8b378a2774a9805e6f0a369e33aacd7
3
+ size 1514372328
feature_extractor/preprocessor_config.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "crop_size": {
3
+ "height": 224,
4
+ "width": 224
5
+ },
6
+ "do_center_crop": true,
7
+ "do_convert_rgb": true,
8
+ "do_normalize": true,
9
+ "do_rescale": true,
10
+ "do_resize": true,
11
+ "feature_extractor_type": "CLIPFeatureExtractor",
12
+ "image_mean": [
13
+ 0.48145466,
14
+ 0.4578275,
15
+ 0.40821073
16
+ ],
17
+ "image_processor_type": "CLIPImageProcessor",
18
+ "image_std": [
19
+ 0.26862954,
20
+ 0.26130258,
21
+ 0.27577711
22
+ ],
23
+ "resample": 3,
24
+ "rescale_factor": 0.00392156862745098,
25
+ "size": {
26
+ "shortest_edge": 224
27
+ }
28
+ }
model_index.json ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": ["pipeline_diffusionsat_controlnet", "DiffusionSatControlNetPipeline"],
3
+ "_diffusers_version": "0.17.0",
4
+ "controlnet": [
5
+ "controlnet",
6
+ "ControlNetModel"
7
+ ],
8
+ "feature_extractor": [
9
+ "transformers",
10
+ "CLIPImageProcessor"
11
+ ],
12
+ "requires_safety_checker": false,
13
+ "safety_checker": [
14
+ null,
15
+ null
16
+ ],
17
+ "scheduler": [
18
+ "diffusers",
19
+ "DDIMScheduler"
20
+ ],
21
+ "text_encoder": [
22
+ "transformers",
23
+ "CLIPTextModel"
24
+ ],
25
+ "tokenizer": [
26
+ "transformers",
27
+ "CLIPTokenizer"
28
+ ],
29
+ "unet": [
30
+ "sat_unet",
31
+ "SatUNet"
32
+ ],
33
+ "vae": [
34
+ "diffusers",
35
+ "AutoencoderKL"
36
+ ]
37
+ }
pipeline_diffusionsat.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Self-contained DiffusionSat text-to-image pipeline that can be loaded directly
3
+ from the checkpoint folder without importing the project package.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ from typing import Any, Callable, Dict, List, Optional, Union
9
+
10
+ import torch
11
+ from packaging import version
12
+ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
13
+
14
+ from diffusers.configuration_utils import FrozenDict
15
+ from diffusers.models import AutoencoderKL
16
+ from diffusers.schedulers import KarrasDiffusionSchedulers
17
+ from diffusers.utils import (
18
+ deprecate,
19
+ logging,
20
+ randn_tensor,
21
+ replace_example_docstring,
22
+ is_accelerate_available,
23
+ )
24
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
25
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
26
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
27
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
28
+ StableDiffusionPipeline as DiffusersStableDiffusionPipeline,
29
+ )
30
+
31
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
+
33
+ EXAMPLE_DOC_STRING = """
34
+ Examples:
35
+ ```py
36
+ >>> import torch
37
+ >>> from diffusers import DiffusionPipeline
38
+
39
+ >>> pipe = DiffusionPipeline.from_pretrained("path/to/ckpt/diffusionsat", torch_dtype=torch.float16)
40
+ >>> pipe = pipe.to("cuda")
41
+
42
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
43
+ >>> image = pipe(prompt).images[0]
44
+ ```
45
+ """
46
+
47
+
48
class DiffusionSatPipeline(DiffusionPipeline):
    """
    Pipeline for text-to-image generation using the DiffusionSat UNet with optional metadata.

    Mirrors `diffusers`' StableDiffusionPipeline but forwards an extra `metadata`
    tensor to the UNet on every denoising step (zero-filled when the caller
    supplies none), as expected by DiffusionSat checkpoints.
    """

    # Components allowed to be None when loading (e.g. `safety_checker=None`).
    _optional_components = ["safety_checker", "feature_extractor"]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: Any,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
        requires_safety_checker: bool = True,
    ):
        super().__init__()

        # Legacy-config fixup: old scheduler configs shipped `steps_offset != 1`.
        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        # Legacy-config fixup: `clip_sample` must be False for SD-style latents.
        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
            )
            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["clip_sample"] = False
            scheduler._internal_dict = FrozenDict(new_config)

        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        if safety_checker is not None and feature_extractor is None:
            # BUG FIX: this message was a plain string, so "{self.__class__}"
            # was printed literally; it is now an f-string as intended.
            raise ValueError(
                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        # Legacy-config fixup: pre-0.9.0 SD checkpoints stored `sample_size < 64`.
        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
            version.parse(unet.config._diffusers_version).base_version
        ) < version.parse("0.9.0.dev0")
        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
            deprecation_message = (
                "The configuration file of the unet has set the default `sample_size` to smaller than"
                " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
                " in the config might lead to incorrect results in future versions. If you have downloaded this"
                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
                " the `unet/config.json` file"
            )
            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(unet.config)
            new_config["sample_size"] = 64
            unet._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    # Borrow helper implementations from diffusers' StableDiffusionPipeline for convenience.
    enable_vae_slicing = DiffusersStableDiffusionPipeline.enable_vae_slicing
    disable_vae_slicing = DiffusersStableDiffusionPipeline.disable_vae_slicing
    enable_sequential_cpu_offload = DiffusersStableDiffusionPipeline.enable_sequential_cpu_offload
    _execution_device = DiffusersStableDiffusionPipeline._execution_device
    _encode_prompt = DiffusersStableDiffusionPipeline._encode_prompt
    run_safety_checker = DiffusersStableDiffusionPipeline.run_safety_checker
    decode_latents = DiffusersStableDiffusionPipeline.decode_latents
    prepare_extra_step_kwargs = DiffusersStableDiffusionPipeline.prepare_extra_step_kwargs
    check_inputs = DiffusersStableDiffusionPipeline.check_inputs
    prepare_latents = DiffusersStableDiffusionPipeline.prepare_latents

    def prepare_metadata(
        self, batch_size, metadata, do_classifier_free_guidance, device, dtype,
    ):
        """Normalize user-supplied metadata into the tensor the UNet expects.

        Args:
            batch_size: Final latent batch size (prompts * images-per-prompt),
                before classifier-free-guidance duplication.
            metadata: ``None``, a flat list / 1-D tensor shared by the whole
                batch, or a ``(rows, num_metadata)`` tensor (one row per prompt).
            do_classifier_free_guidance: Whether the batch is doubled for CFG.
            device: Target device.
            dtype: Target dtype (normally ``prompt_embeds.dtype``).

        Returns:
            A ``(batch, num_metadata)`` tensor — prefixed with a zero
            "unconditional" half when CFG is enabled — or ``None`` when the
            UNet takes no metadata.
        """
        has_metadata = getattr(self.unet.config, "use_metadata", False)
        num_metadata = getattr(self.unet.config, "num_metadata", 0)

        # Auto-fill zeros when the UNet requires metadata but none was given.
        if metadata is None and has_metadata and num_metadata > 0:
            metadata = torch.zeros((batch_size, num_metadata), device=device, dtype=dtype)

        if metadata is None:
            return None

        md = torch.tensor(metadata) if not torch.is_tensor(metadata) else metadata
        if len(md.shape) == 1:
            # Single metadata vector: share it across the whole batch.
            md = md.unsqueeze(0).expand(batch_size, -1)
        elif md.shape[0] != batch_size and md.shape[0] > 0 and batch_size % md.shape[0] == 0:
            # One row per prompt: repeat rows to cover num_images_per_prompt copies.
            md = md.repeat_interleave(batch_size // md.shape[0], dim=0)
        md = md.to(device=device, dtype=dtype)

        if do_classifier_free_guidance:
            # Unconditional half gets zero metadata, mirroring the empty prompt.
            md = torch.cat([torch.zeros_like(md), md])

        return md

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        metadata: Optional[List[float]] = None,
    ):
        r"""
        Generate images from text, optionally conditioned on DiffusionSat metadata.

        BUG FIX: `__call__` previously had no docstring, but
        `@replace_example_docstring` reads `fn.__doc__` and requires an
        `Examples:` section — the decorator would fail at class-definition time.

        Args:
            prompt: Prompt(s) to guide generation; required unless `prompt_embeds` is given.
            metadata: Optional metadata values of length `unet.config.num_metadata`
                (zero-filled when omitted and the UNet uses metadata). Other
                arguments match `StableDiffusionPipeline.__call__`.

        Examples:

        Returns:
            `StableDiffusionPipelineOutput` if `return_dict=True`, else a
            `(images, nsfw_content_detected)` tuple.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # CFG is enabled exactly when guidance_scale > 1 (matching upstream SD).
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.in_channels if hasattr(self.unet, "in_channels") else self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 6.5: Prepare metadata (auto-zero filled when missing).
        # BUG FIX: was called with `batch_size`, which mismatched the latent /
        # prompt-embedding batch (batch_size * num_images_per_prompt) and tripped
        # the shape assertion below whenever num_images_per_prompt > 1.
        input_metadata = self.prepare_metadata(
            batch_size * num_images_per_prompt, metadata, do_classifier_free_guidance, device, prompt_embeds.dtype
        )
        if input_metadata is not None:
            assert input_metadata.shape[-1] == getattr(self.unet.config, "num_metadata", input_metadata.shape[-1])
            assert input_metadata.shape[0] == prompt_embeds.shape[0]

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # Duplicate latents for the (uncond, cond) halves under CFG.
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    metadata=input_metadata,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # 8. Post-processing: latent passthrough, PIL, or raw numpy.
        if output_type == "latent":
            image = latents
            has_nsfw_concept = None
        elif output_type == "pil":
            image = self.decode_latents(latents)
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
            image = self.numpy_to_pil(image)
        else:
            image = self.decode_latents(latents)
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
301
+
302
+
303
+ __all__ = ["DiffusionSatPipeline"]
pipeline_diffusionsat_controlnet.py ADDED
@@ -0,0 +1,425 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Self-contained DiffusionSat ControlNet pipeline that can be loaded directly from
3
+ the checkpoint folder without importing the project package.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import os
9
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
10
+
11
+ import einops
12
+ import numpy as np
13
+ import PIL.Image
14
+ import torch
15
+ import torch.nn.functional as F
16
+ from torch import nn
17
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
18
+
19
+ from diffusers.loaders import TextualInversionLoaderMixin
20
+ from diffusers.models import AutoencoderKL
21
+ from diffusers.schedulers import KarrasDiffusionSchedulers
22
+ from diffusers.utils import (
23
+ PIL_INTERPOLATION,
24
+ logging,
25
+ randn_tensor,
26
+ replace_example_docstring,
27
+ is_accelerate_available,
28
+ is_accelerate_version,
29
+ )
30
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
31
+ from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
32
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
33
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
34
+ StableDiffusionPipeline as DiffusersStableDiffusionPipeline,
35
+ )
36
+ from diffusers.pipelines.controlnet.pipeline_controlnet import (
37
+ StableDiffusionControlNetPipeline as DiffusersControlNetPipeline,
38
+ )
39
+
40
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41
+
42
+ EXAMPLE_DOC_STRING = """
43
+ Examples:
44
+ ```py
45
+ >>> # !pip install opencv-python transformers accelerate
46
+ >>> from diffusers import DiffusionPipeline
47
+ >>> from diffusers.utils import load_image
48
+ >>> import numpy as np
49
+ >>> import torch
50
+ >>> import cv2
51
+ >>> from PIL import Image
52
+ >>>
53
+ >>> image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")
54
+ >>> image = np.array(image)
55
+ >>> image = cv2.Canny(image, 100, 200)
56
+ >>> image = image[:, :, None]
57
+ >>> image = np.concatenate([image, image, image], axis=2)
58
+ >>> canny_image = Image.fromarray(image)
59
+ >>>
60
+ >>> pipe = DiffusionPipeline.from_pretrained("path/to/ckpt/diffusionsat", torch_dtype=torch.float16)
61
+ >>> pipe = pipe.to("cuda")
62
+ >>> pipe.enable_xformers_memory_efficient_attention()
63
+ >>> generator = torch.manual_seed(0)
64
+ >>> image = pipe(
65
+ ... "futuristic-looking woman", num_inference_steps=20, generator=generator, image=canny_image
66
+ ... ).images[0]
67
+ ```
68
+ """
69
+
70
+
71
class DiffusionSatControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
    """
    ControlNet-aware pipeline for DiffusionSat. This is a mostly direct copy of
    the project pipeline to avoid importing the `diffusionsat` package when
    loading from the checkpoint folder. Minimal tweaks:
    - auto-fills metadata/cond_metadata with zeros when the model expects them.
    """

    # Components that may legitimately be missing from the checkpoint folder.
    _optional_components = ["safety_checker", "feature_extractor"]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: Any,
        controlnet: Any,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
    ):
        """Register all sub-models with the pipeline.

        Args:
            vae: Autoencoder that maps images to/from the latent space.
            text_encoder: CLIP text encoder producing prompt embeddings.
            tokenizer: CLIP tokenizer matching ``text_encoder``.
            unet: Denoising UNet (expected to accept a ``metadata`` kwarg).
            controlnet: Single ControlNet, or a list/tuple of ControlNets which
                is wrapped into a ``MultiControlNetModel``.
            scheduler: Any Karras-style diffusion scheduler.
            safety_checker: NSFW classifier; may be ``None`` (a warning is logged).
            feature_extractor: Image processor feeding the safety checker.
            requires_safety_checker: Stored in the config; controls the warning.

        Raises:
            ValueError: if ``safety_checker`` is set but ``feature_extractor`` is not.
        """
        super().__init__()

        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results."
            )

        if safety_checker is not None and feature_extractor is None:
            raise ValueError(
                # BUG FIX: this message was a plain string, so "{self.__class__}"
                # was emitted literally; it is now an f-string.
                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        # Support MultiControlNetModel-like objects without importing the project module.
        if isinstance(controlnet, (list, tuple)):
            # defer to diffusers' MultiControlNetModel if available
            from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel

            controlnet = MultiControlNetModel(controlnet)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            controlnet=controlnet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        # Spatial factor between pixel space and the VAE latent space
        # (derived from the number of VAE down-blocks).
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    # Reuse helpers from diffusers baseline pipelines rather than re-copying them.
    enable_vae_slicing = DiffusersStableDiffusionPipeline.enable_vae_slicing
    disable_vae_slicing = DiffusersStableDiffusionPipeline.disable_vae_slicing
    enable_vae_tiling = DiffusersStableDiffusionPipeline.enable_vae_tiling
    disable_vae_tiling = DiffusersStableDiffusionPipeline.disable_vae_tiling
    enable_sequential_cpu_offload = DiffusersControlNetPipeline.enable_sequential_cpu_offload
    enable_model_cpu_offload = DiffusersControlNetPipeline.enable_model_cpu_offload
    _execution_device = DiffusersStableDiffusionPipeline._execution_device
    _encode_prompt = DiffusersStableDiffusionPipeline._encode_prompt
    run_safety_checker = DiffusersStableDiffusionPipeline.run_safety_checker
    decode_latents = DiffusersStableDiffusionPipeline.decode_latents
    prepare_extra_step_kwargs = DiffusersStableDiffusionPipeline.prepare_extra_step_kwargs
    check_inputs = DiffusersControlNetPipeline.check_inputs
    check_image = DiffusersControlNetPipeline.check_image
    prepare_image = DiffusersControlNetPipeline.prepare_image
    prepare_latents = DiffusersStableDiffusionPipeline.prepare_latents
146
+
147
+ def prepare_metadata(self, batch_size, metadata, ndims, do_classifier_free_guidance, device, dtype):
148
+ has_metadata = getattr(self.unet.config, "use_metadata", False)
149
+ num_metadata = getattr(self.unet.config, "num_metadata", 0)
150
+
151
+ if metadata is None and has_metadata and num_metadata > 0:
152
+ shape = (batch_size, num_metadata) if ndims == 2 else (batch_size, num_metadata, 1)
153
+ metadata = torch.zeros(shape, device=device, dtype=dtype)
154
+
155
+ if metadata is None:
156
+ return None
157
+
158
+ md = torch.as_tensor(metadata)
159
+ if ndims == 2:
160
+ assert (len(md.shape) == 1 and batch_size == 1) or (len(md.shape) == 2 and batch_size > 1)
161
+ if len(md.shape) == 1:
162
+ md = md.unsqueeze(0).expand(batch_size, -1)
163
+ elif ndims == 3:
164
+ assert (len(md.shape) == 2 and batch_size == 1) or (len(md.shape) == 3 and batch_size > 1)
165
+ if len(md.shape) == 2:
166
+ md = md.unsqueeze(0).expand(batch_size, -1, -1)
167
+
168
+ if do_classifier_free_guidance:
169
+ md = torch.cat([torch.zeros_like(md), md])
170
+
171
+ md = md.to(device=device, dtype=dtype)
172
+ return md
173
+
174
+ def _default_height_width(self, height, width, image):
175
+ while isinstance(image, list):
176
+ image = image[0]
177
+
178
+ if height is None:
179
+ if isinstance(image, PIL.Image.Image):
180
+ height = image.height
181
+ elif isinstance(image, torch.Tensor):
182
+ height = image.shape[2]
183
+ height = (height // 8) * 8
184
+
185
+ if width is None:
186
+ if isinstance(image, PIL.Image.Image):
187
+ width = image.width
188
+ elif isinstance(image, torch.Tensor):
189
+ width = image.shape[3]
190
+ width = (width // 8) * 8
191
+
192
+ return height, width
193
+
194
    # override DiffusionPipeline
    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        safe_serialization: bool = False,
        variant: Optional[str] = None,
    ):
        """Serialize all registered components into ``save_directory``.

        Kept only for interface parity with the project pipeline; it defers
        entirely to ``DiffusionPipeline.save_pretrained``.
        """
        # For single or multi controlnet, rely on default save logic.
        super().save_pretrained(save_directory, safe_serialization=safe_serialization, variant=variant)
203
+
204
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
        guess_mode: bool = False,
        metadata: Optional[List[float]] = None,
        cond_metadata: Optional[List[float]] = None,
        is_temporal: bool = False,
        conditioning_downsample: bool = True,
    ):
        r"""Generate images from a prompt and ControlNet conditioning image(s).

        DiffusionSat-specific arguments on top of the standard diffusers
        ControlNet signature:
            metadata: per-sample metadata for the UNet (auto zero-filled when
                omitted and the model expects it); ``cond_metadata`` is the
                analogous metadata for the ControlNet conditioning image(s).
            is_temporal: reshape a channel-stacked conditioning tensor into a
                (b, c, t, h, w) sequence for a temporal ControlNet.
            conditioning_downsample: when ``False``, the conditioning image is
                prepared at 1/8 of the output resolution (latent-sized input).

        NOTE(review): ``replace_example_docstring`` substitutes
        ``EXAMPLE_DOC_STRING`` into the ``Examples:`` section below — confirm
        the installed diffusers version tolerates the docstring layout.

        Examples:

        Returns:
            `StableDiffusionPipelineOutput` or, when ``return_dict=False``, a
            tuple ``(images, nsfw_content_detected)``.
        """
        # 0. Default height and width to unet
        height, width = self._default_height_width(height, width, image)
        cond_height, cond_width = height, width
        if not conditioning_downsample:
            # Conditioning is consumed at latent resolution instead of pixel resolution.
            cond_height, cond_width = height // 8, width // 8

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            image,
            height,
            width,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            controlnet_conditioning_scale,
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            # Caller supplied precomputed embeddings instead of a prompt.
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # CFG is effectively disabled for guidance_scale <= 1.
        do_classifier_free_guidance = guidance_scale > 1.0

        from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel

        if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
            # Broadcast a scalar scale to every sub-ControlNet.
            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)

        # 3. Encode input prompt
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        # 4. Prepare image
        # torch.compile wraps the model; the real module then lives on `_orig_mod`.
        is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
            self.controlnet, torch._dynamo.eval_frame.OptimizedModule
        )
        is_multi_cond = isinstance(image, list)

        # Only preprocess here when the ControlNet has a standard conditioning
        # embedding; otherwise the raw input is forwarded as-is.
        if (
            hasattr(self.controlnet, "controlnet_cond_embedding")
            or is_compiled
            and hasattr(self.controlnet._orig_mod, "controlnet_cond_embedding")
        ):
            image = self.prepare_image(
                image=image,
                width=cond_width,
                height=cond_height,
                batch_size=batch_size * num_images_per_prompt,
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                dtype=self.controlnet.dtype,
                do_classifier_free_guidance=do_classifier_free_guidance,
                guess_mode=guess_mode,
            )

        # 5. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 6. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 7. Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # CUSTOM metadata handling (auto-zero filled)
        input_metadata = self.prepare_metadata(batch_size, metadata, 2, do_classifier_free_guidance, device, prompt_embeds.dtype)
        # Multi-image conditioning carries one metadata vector per frame (3-D).
        ndims_cond = 3 if is_multi_cond else 2
        cond_metadata = self.prepare_metadata(
            batch_size, cond_metadata, ndims_cond, do_classifier_free_guidance, device, prompt_embeds.dtype
        )
        if input_metadata is not None:
            assert len(input_metadata.shape) == 2 and input_metadata.shape[-1] == getattr(self.unet.config, "num_metadata", input_metadata.shape[-1])
        if cond_metadata is not None:
            assert len(cond_metadata.shape) == ndims_cond and cond_metadata.shape[1] == getattr(self.unet.config, "num_metadata", cond_metadata.shape[1])
            if is_multi_cond and not is_temporal and not isinstance(self.controlnet, MultiControlNetModel):
                # One metadata column per RGB conditioning frame stacked into conv_in.
                assert cond_metadata.shape[2] == self.controlnet.controlnet_cond_embedding.conv_in.in_channels / 3
        if input_metadata is not None:
            assert input_metadata.shape[0] == prompt_embeds.shape[0]

        if is_temporal:
            # Unstack channel-concatenated frames into an explicit time axis.
            num_cond = cond_metadata.shape[-1] if cond_metadata is not None else image.shape[1] // self.controlnet.config.conditioning_in_channels
            image = einops.rearrange(image, 'b (t c) h w -> b c t h w', t=num_cond)
        elif isinstance(self.controlnet, MultiControlNetModel) and cond_metadata is not None:
            # Split stacked frames into per-ControlNet lists (one entry per net).
            num_cond = cond_metadata.shape[-1] if cond_metadata is not None else image.shape[1] // self.controlnet.config.conditioning_in_channels
            image = einops.rearrange(image, 'b (t c) h w -> t b c h w', t=num_cond)
            image = [im for im in image]
            cond_metadata = einops.rearrange(cond_metadata, 'b m t -> t b m')
            cond_metadata = [cond_md for cond_md in cond_metadata]

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # Duplicate latents for the uncond/cond halves under CFG.
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                if guess_mode and do_classifier_free_guidance:
                    # Guess mode: ControlNet only sees the conditional branch.
                    controlnet_latent_model_input = latents
                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
                else:
                    controlnet_latent_model_input = latent_model_input
                    controlnet_prompt_embeds = prompt_embeds

                down_block_res_samples, mid_block_res_sample = self.controlnet(
                    controlnet_latent_model_input,
                    t,
                    encoder_hidden_states=controlnet_prompt_embeds,
                    controlnet_cond=image,
                    metadata=input_metadata,
                    cond_metadata=cond_metadata,
                    conditioning_scale=controlnet_conditioning_scale,
                    guess_mode=guess_mode,
                    return_dict=False,
                )

                if guess_mode and do_classifier_free_guidance:
                    # Zero residuals for the unconditional half so only the
                    # conditional branch receives ControlNet guidance.
                    down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
                    mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])

                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    metadata=input_metadata,
                    cross_attention_kwargs=cross_attention_kwargs,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                    return_dict=False,
                )[0]

                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                # Tick the bar once the warmup steps are consumed (or on the last step).
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            # Free GPU memory before VAE decode under model offloading.
            self.unet.to("cpu")
            self.controlnet.to("cpu")
            torch.cuda.empty_cache()

        # NOTE: `image` is reused below as the OUTPUT image variable.
        if output_type == "latent":
            image = latents
            has_nsfw_concept = None
        elif output_type == "pil":
            image = self.decode_latents(latents)
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
            image = self.numpy_to_pil(image)
        else:
            # numpy output
            image = self.decode_latents(latents)
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
423
+
424
+
425
+ __all__ = ["DiffusionSatControlNetPipeline"]
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "DDIMScheduler",
3
+ "_diffusers_version": "0.17.0",
4
+ "beta_end": 0.012,
5
+ "beta_schedule": "scaled_linear",
6
+ "beta_start": 0.00085,
7
+ "clip_sample": false,
8
+ "clip_sample_range": 1.0,
9
+ "dynamic_thresholding_ratio": 0.995,
10
+ "num_train_timesteps": 1000,
11
+ "prediction_type": "v_prediction",
12
+ "rescale_betas_zero_snr": false,
13
+ "sample_max_value": 1.0,
14
+ "set_alpha_to_one": false,
15
+ "skip_prk_steps": true,
16
+ "steps_offset": 1,
17
+ "thresholding": false,
18
+ "timestep_spacing": "leading",
19
+ "trained_betas": null
20
+ }
text_encoder/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "stabilityai/stable-diffusion-2-1",
3
+ "architectures": [
4
+ "CLIPTextModel"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 0,
8
+ "dropout": 0.0,
9
+ "eos_token_id": 2,
10
+ "hidden_act": "gelu",
11
+ "hidden_size": 1024,
12
+ "initializer_factor": 1.0,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 4096,
15
+ "layer_norm_eps": 1e-05,
16
+ "max_position_embeddings": 77,
17
+ "model_type": "clip_text_model",
18
+ "num_attention_heads": 16,
19
+ "num_hidden_layers": 23,
20
+ "pad_token_id": 1,
21
+ "projection_dim": 512,
22
+ "torch_dtype": "float16",
23
+ "transformers_version": "4.31.0",
24
+ "vocab_size": 49408
25
+ }
text_encoder/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc1827c465450322616f06dea41596eac7d493f4e95904dcb51f0fc745c4e13f
3
+ size 680820392
tokenizer/config.json ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "ControlNetModel",
3
+ "_diffusers_version": "0.17.0",
4
+ "_name_or_path": "/data/jiabo/diffusionsat/testoutput/checkpoint-1",
5
+ "act_fn": "silu",
6
+ "attention_head_dim": [
7
+ 5,
8
+ 10,
9
+ 20,
10
+ 20
11
+ ],
12
+ "block_out_channels": [
13
+ 320,
14
+ 640,
15
+ 1280,
16
+ 1280
17
+ ],
18
+ "class_embed_type": null,
19
+ "conditioning_embedding_out_channels": [
20
+ 16,
21
+ 32,
22
+ 96,
23
+ 256
24
+ ],
25
+ "conditioning_in_channels": 3,
26
+ "conditioning_scale": 1,
27
+ "controlnet_conditioning_channel_order": "rgb",
28
+ "cross_attention_dim": 1024,
29
+ "down_block_types": [
30
+ "CrossAttnDownBlock2D",
31
+ "CrossAttnDownBlock2D",
32
+ "CrossAttnDownBlock2D",
33
+ "DownBlock2D"
34
+ ],
35
+ "downsample_padding": 1,
36
+ "flip_sin_to_cos": true,
37
+ "freq_shift": 0,
38
+ "global_pool_conditions": false,
39
+ "in_channels": 4,
40
+ "layers_per_block": 2,
41
+ "mid_block_scale_factor": 1,
42
+ "norm_eps": 1e-05,
43
+ "norm_num_groups": 32,
44
+ "num_class_embeds": null,
45
+ "num_metadata": 7,
46
+ "only_cross_attention": false,
47
+ "projection_class_embeddings_input_dim": null,
48
+ "resnet_time_scale_shift": "default",
49
+ "upcast_attention": true,
50
+ "use_linear_projection": true,
51
+ "use_metadata": true
52
+ }
tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|startoftext|>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": true,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "!",
17
+ "unk_token": {
18
+ "content": "<|endoftext|>",
19
+ "lstrip": false,
20
+ "normalized": true,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "bos_token": {
4
+ "__type": "AddedToken",
5
+ "content": "<|startoftext|>",
6
+ "lstrip": false,
7
+ "normalized": true,
8
+ "rstrip": false,
9
+ "single_word": false
10
+ },
11
+ "clean_up_tokenization_spaces": true,
12
+ "do_lower_case": true,
13
+ "eos_token": {
14
+ "__type": "AddedToken",
15
+ "content": "<|endoftext|>",
16
+ "lstrip": false,
17
+ "normalized": true,
18
+ "rstrip": false,
19
+ "single_word": false
20
+ },
21
+ "errors": "replace",
22
+ "model_max_length": 77,
23
+ "pad_token": "<|endoftext|>",
24
+ "tokenizer_class": "CLIPTokenizer",
25
+ "unk_token": {
26
+ "__type": "AddedToken",
27
+ "content": "<|endoftext|>",
28
+ "lstrip": false,
29
+ "normalized": true,
30
+ "rstrip": false,
31
+ "single_word": false
32
+ }
33
+ }
tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
unet/config.json ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": ["sat_unet", "SatUNet"],
3
+ "_diffusers_version": "0.17.0",
4
+ "act_fn": "silu",
5
+ "attention_head_dim": [
6
+ 5,
7
+ 10,
8
+ 20,
9
+ 20
10
+ ],
11
+ "block_out_channels": [
12
+ 320,
13
+ 640,
14
+ 1280,
15
+ 1280
16
+ ],
17
+ "center_input_sample": false,
18
+ "class_embed_type": null,
19
+ "conv_in_kernel": 3,
20
+ "conv_out_kernel": 3,
21
+ "cross_attention_dim": 1024,
22
+ "down_block_types": [
23
+ "CrossAttnDownBlock2D",
24
+ "CrossAttnDownBlock2D",
25
+ "CrossAttnDownBlock2D",
26
+ "DownBlock2D"
27
+ ],
28
+ "downsample_padding": 1,
29
+ "dual_cross_attention": false,
30
+ "flip_sin_to_cos": true,
31
+ "freq_shift": 0,
32
+ "in_channels": 4,
33
+ "layers_per_block": 2,
34
+ "mid_block_scale_factor": 1,
35
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
36
+ "norm_eps": 1e-05,
37
+ "norm_num_groups": 32,
38
+ "num_class_embeds": null,
39
+ "num_metadata": 7,
40
+ "only_cross_attention": false,
41
+ "out_channels": 4,
42
+ "resnet_time_scale_shift": "default",
43
+ "sample_size": 96,
44
+ "time_cond_proj_dim": null,
45
+ "time_embedding_type": "positional",
46
+ "timestep_post_act": null,
47
+ "up_block_types": [
48
+ "UpBlock2D",
49
+ "CrossAttnUpBlock2D",
50
+ "CrossAttnUpBlock2D",
51
+ "CrossAttnUpBlock2D"
52
+ ],
53
+ "upcast_attention": true,
54
+ "use_linear_projection": true,
55
+ "use_metadata": true
56
+ }
unet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ef6c0264f8eb5085b08e5f16631e1ad8ba078f28d94c902276f8dfc603e3eb80
3
+ size 1760615624
unet/sat_unet.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Satellite UNet wrapper with metadata support on top of diffusers."""
2
+ from typing import Any, Dict, Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from diffusers.models.unets.unet_2d_condition import (
8
+ UNet2DConditionModel,
9
+ UNet2DConditionOutput,
10
+ )
11
+ from diffusers.utils import logging
12
+
13
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
14
+
15
+
16
class SatUNet(UNet2DConditionModel):
    """Thin wrapper around `diffusers.UNet2DConditionModel` with metadata embeddings.

    Adds `num_metadata` independent per-field embeddings whose summed output is
    added to the timestep embedding inside `forward` (via `_encode_metadata`).
    """

    _supports_gradient_checkpointing = True

    def __init__(self, *args, use_metadata: bool = True, num_metadata: int = 7, **kwargs):
        """Build the base UNet, then attach one embedding module per metadata field.

        Args:
            use_metadata: whether to create (and later require) metadata embeddings.
            num_metadata: number of scalar metadata fields per sample.
        """
        super().__init__(*args, **kwargs)

        # Track custom config entries for save/load parity with the base model.
        self.register_to_config(use_metadata=use_metadata, num_metadata=num_metadata)

        self.use_metadata = use_metadata
        self.num_metadata = num_metadata

        if use_metadata:
            # Re-use the same dimensions as the base time embedding so the
            # metadata embedding can be summed with the timestep embedding.
            timestep_input_dim = self.time_embedding.linear_1.in_features
            time_embed_dim = self.time_embedding.linear_2.out_features
            self.metadata_embedding = nn.ModuleList(
                [self._build_metadata_embedding(timestep_input_dim, time_embed_dim) for _ in range(num_metadata)]
            )
        else:
            # Sentinel: _encode_metadata short-circuits to None when disabled.
            self.metadata_embedding = None
39
+
40
+ @staticmethod
41
+ def _build_metadata_embedding(timestep_input_dim: int, time_embed_dim: int) -> nn.Module:
42
+ from diffusers.models.embeddings import TimestepEmbedding
43
+
44
+ return TimestepEmbedding(timestep_input_dim, time_embed_dim)
45
+
46
+ def _encode_metadata(
47
+ self, metadata: Optional[torch.Tensor], dtype: torch.dtype
48
+ ) -> Optional[torch.Tensor]:
49
+ if self.metadata_embedding is None:
50
+ return None
51
+
52
+ if metadata is None:
53
+ raise ValueError("metadata must be provided when use_metadata=True")
54
+
55
+ if metadata.dim() != 2 or metadata.shape[1] != self.num_metadata:
56
+ raise ValueError(f"Invalid metadata shape {metadata.shape}, expected (batch, {self.num_metadata})")
57
+
58
+ md_bsz = metadata.shape[0]
59
+ # Reuse the same projection used for timestep encoding to stay aligned with base embeddings.
60
+ projected = self.time_proj(metadata.view(-1)).view(md_bsz, self.num_metadata, -1).to(dtype=dtype)
61
+
62
+ md_emb = projected.new_zeros((md_bsz, projected.shape[-1]))
63
+ for idx, md_embed in enumerate(self.metadata_embedding):
64
+ md_emb = md_emb + md_embed(projected[:, idx, :])
65
+
66
+ return md_emb
67
+
68
+ def forward(
69
+ self,
70
+ sample: torch.Tensor,
71
+ timestep: Union[torch.Tensor, float, int],
72
+ encoder_hidden_states: torch.Tensor,
73
+ class_labels: Optional[torch.Tensor] = None,
74
+ timestep_cond: Optional[torch.Tensor] = None,
75
+ attention_mask: Optional[torch.Tensor] = None,
76
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
77
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
78
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
79
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
80
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
81
+ encoder_attention_mask: Optional[torch.Tensor] = None,
82
+ metadata: Optional[torch.Tensor] = None,
83
+ return_dict: bool = True,
84
+ ) -> Union[UNet2DConditionOutput, Tuple]:
85
+ # Largely mirrors `UNet2DConditionModel.forward` with a metadata injection on the timestep embedding.
86
+
87
+ default_overall_up_factor = 2**self.num_upsamplers
88
+ forward_upsample_size = False
89
+ upsample_size = None
90
+
91
+ for dim in sample.shape[-2:]:
92
+ if dim % default_overall_up_factor != 0:
93
+ forward_upsample_size = True
94
+ break
95
+
96
+ if attention_mask is not None:
97
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
98
+ attention_mask = attention_mask.unsqueeze(1)
99
+
100
+ if encoder_attention_mask is not None:
101
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
102
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
103
+
104
+ if self.config.center_input_sample:
105
+ sample = 2 * sample - 1.0
106
+
107
+ t_emb = self.get_time_embed(sample=sample, timestep=timestep)
108
+ emb = self.time_embedding(t_emb, timestep_cond)
109
+
110
+ class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
111
+ if class_emb is not None:
112
+ if self.config.class_embeddings_concat:
113
+ emb = torch.cat([emb, class_emb], dim=-1)
114
+ else:
115
+ emb = emb + class_emb
116
+
117
+ aug_emb = self.get_aug_embed(
118
+ emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs or {}
119
+ )
120
+ if self.config.addition_embed_type == "image_hint" and aug_emb is not None:
121
+ aug_emb, hint = aug_emb
122
+ sample = torch.cat([sample, hint], dim=1)
123
+
124
+ emb = emb + aug_emb if aug_emb is not None else emb
125
+
126
+ md_emb = self._encode_metadata(metadata=metadata, dtype=sample.dtype)
127
+ if md_emb is not None:
128
+ emb = emb + md_emb
129
+
130
+ if self.time_embed_act is not None:
131
+ emb = self.time_embed_act(emb)
132
+
133
+ encoder_hidden_states = self.process_encoder_hidden_states(
134
+ encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs or {}
135
+ )
136
+
137
+ sample = self.conv_in(sample)
138
+
139
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
140
+ cross_attention_kwargs = cross_attention_kwargs.copy()
141
+ gligen_args = cross_attention_kwargs.pop("gligen")
142
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
143
+
144
+ if cross_attention_kwargs is not None:
145
+ cross_attention_kwargs = cross_attention_kwargs.copy()
146
+ lora_scale = cross_attention_kwargs.pop("scale", 1.0)
147
+ else:
148
+ lora_scale = 1.0
149
+
150
+ from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers, deprecate
151
+
152
+ if USE_PEFT_BACKEND:
153
+ scale_lora_layers(self, lora_scale)
154
+
155
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
156
+ is_adapter = down_intrablock_additional_residuals is not None
157
+ if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
158
+ deprecate(
159
+ "T2I should not use down_block_additional_residuals",
160
+ "1.3.0",
161
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated "
162
+ "and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used "
163
+ "for ControlNet. Please use `down_intrablock_additional_residuals` instead.",
164
+ standard_warn=False,
165
+ )
166
+ down_intrablock_additional_residuals = down_block_additional_residuals
167
+ is_adapter = True
168
+
169
+ down_block_res_samples = (sample,)
170
+ for downsample_block in self.down_blocks:
171
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
172
+ additional_residuals: Dict[str, torch.Tensor] = {}
173
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
174
+ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
175
+
176
+ sample, res_samples = downsample_block(
177
+ hidden_states=sample,
178
+ temb=emb,
179
+ encoder_hidden_states=encoder_hidden_states,
180
+ attention_mask=attention_mask,
181
+ cross_attention_kwargs=cross_attention_kwargs,
182
+ encoder_attention_mask=encoder_attention_mask,
183
+ **additional_residuals,
184
+ )
185
+ else:
186
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
187
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
188
+ sample += down_intrablock_additional_residuals.pop(0)
189
+
190
+ down_block_res_samples += res_samples
191
+
192
+ if is_controlnet:
193
+ new_down_block_res_samples = ()
194
+
195
+ for down_block_res_sample, down_block_additional_residual in zip(
196
+ down_block_res_samples, down_block_additional_residuals
197
+ ):
198
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
199
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
200
+
201
+ down_block_res_samples = new_down_block_res_samples
202
+
203
+ if self.mid_block is not None:
204
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
205
+ sample = self.mid_block(
206
+ sample,
207
+ emb,
208
+ encoder_hidden_states=encoder_hidden_states,
209
+ attention_mask=attention_mask,
210
+ cross_attention_kwargs=cross_attention_kwargs,
211
+ encoder_attention_mask=encoder_attention_mask,
212
+ )
213
+ else:
214
+ sample = self.mid_block(sample, emb)
215
+
216
+ if (
217
+ is_adapter
218
+ and len(down_intrablock_additional_residuals) > 0
219
+ and sample.shape == down_intrablock_additional_residuals[0].shape
220
+ ):
221
+ sample += down_intrablock_additional_residuals.pop(0)
222
+
223
+ if is_controlnet:
224
+ sample = sample + mid_block_additional_residual
225
+
226
+ for i, upsample_block in enumerate(self.up_blocks):
227
+ is_final_block = i == len(self.up_blocks) - 1
228
+
229
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
230
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
231
+
232
+ if not is_final_block and forward_upsample_size:
233
+ upsample_size = down_block_res_samples[-1].shape[2:]
234
+
235
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
236
+ sample = upsample_block(
237
+ hidden_states=sample,
238
+ temb=emb,
239
+ res_hidden_states_tuple=res_samples,
240
+ encoder_hidden_states=encoder_hidden_states,
241
+ cross_attention_kwargs=cross_attention_kwargs,
242
+ upsample_size=upsample_size,
243
+ attention_mask=attention_mask,
244
+ encoder_attention_mask=encoder_attention_mask,
245
+ )
246
+ else:
247
+ sample = upsample_block(
248
+ hidden_states=sample,
249
+ temb=emb,
250
+ res_hidden_states_tuple=res_samples,
251
+ upsample_size=upsample_size,
252
+ )
253
+
254
+ if self.conv_norm_out:
255
+ sample = self.conv_norm_out(sample)
256
+ sample = self.conv_act(sample)
257
+ sample = self.conv_out(sample)
258
+
259
+ if USE_PEFT_BACKEND:
260
+ unscale_lora_layers(self, lora_scale)
261
+
262
+ if not return_dict:
263
+ return (sample,)
264
+
265
+ return UNet2DConditionOutput(sample=sample)
vae/config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.17.0",
4
+ "_name_or_path": "stabilityai/stable-diffusion-2-1",
5
+ "act_fn": "silu",
6
+ "block_out_channels": [
7
+ 128,
8
+ 256,
9
+ 512,
10
+ 512
11
+ ],
12
+ "down_block_types": [
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D",
16
+ "DownEncoderBlock2D"
17
+ ],
18
+ "in_channels": 3,
19
+ "latent_channels": 4,
20
+ "layers_per_block": 2,
21
+ "norm_num_groups": 32,
22
+ "out_channels": 3,
23
+ "sample_size": 768,
24
+ "scaling_factor": 0.18215,
25
+ "up_block_types": [
26
+ "UpDecoderBlock2D",
27
+ "UpDecoderBlock2D",
28
+ "UpDecoderBlock2D",
29
+ "UpDecoderBlock2D"
30
+ ]
31
+ }
vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3e4c08995484ee61270175e9e7a072b66a6e4eeb5f0c266667fe1f45b90daf9a
3
+ size 167335342