BiliSakura commited on
Commit
0339cee
·
verified ·
1 Parent(s): 4907452

Update all files for DiffusionSat-Single-256

Browse files
Files changed (1) hide show
  1. pipeline_diffusionsat.py +303 -0
pipeline_diffusionsat.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Self-contained DiffusionSat text-to-image pipeline that can be loaded directly
3
+ from the checkpoint folder without importing the project package.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ from typing import Any, Callable, Dict, List, Optional, Union
9
+
10
+ import torch
11
+ from packaging import version
12
+ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
13
+
14
+ from diffusers.configuration_utils import FrozenDict
15
+ from diffusers.models import AutoencoderKL
16
+ from diffusers.schedulers import KarrasDiffusionSchedulers
17
+ from diffusers.utils import (
18
+ deprecate,
19
+ logging,
20
+ randn_tensor,
21
+ replace_example_docstring,
22
+ is_accelerate_available,
23
+ )
24
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
25
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
26
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
27
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
28
+ StableDiffusionPipeline as DiffusersStableDiffusionPipeline,
29
+ )
30
+
31
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
+
33
+ EXAMPLE_DOC_STRING = """
34
+ Examples:
35
+ ```py
36
+ >>> import torch
37
+ >>> from diffusers import DiffusionPipeline
38
+
39
+ >>> pipe = DiffusionPipeline.from_pretrained("path/to/ckpt/diffusionsat", torch_dtype=torch.float16)
40
+ >>> pipe = pipe.to("cuda")
41
+
42
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
43
+ >>> image = pipe(prompt).images[0]
44
+ ```
45
+ """
46
+
47
+
48
+ class DiffusionSatPipeline(DiffusionPipeline):
49
+ """
50
+ Pipeline for text-to-image generation using the DiffusionSat UNet with optional metadata.
51
+ """
52
+
53
+ _optional_components = ["safety_checker", "feature_extractor"]
54
+
55
+ def __init__(
56
+ self,
57
+ vae: AutoencoderKL,
58
+ text_encoder: CLIPTextModel,
59
+ tokenizer: CLIPTokenizer,
60
+ unet: Any,
61
+ scheduler: KarrasDiffusionSchedulers,
62
+ safety_checker: StableDiffusionSafetyChecker,
63
+ feature_extractor: CLIPFeatureExtractor,
64
+ requires_safety_checker: bool = True,
65
+ ):
66
+ super().__init__()
67
+
68
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
69
+ deprecation_message = (
70
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
71
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
72
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
73
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
74
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
75
+ " file"
76
+ )
77
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
78
+ new_config = dict(scheduler.config)
79
+ new_config["steps_offset"] = 1
80
+ scheduler._internal_dict = FrozenDict(new_config)
81
+
82
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
83
+ deprecation_message = (
84
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
85
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
86
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
87
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
88
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
89
+ )
90
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
91
+ new_config = dict(scheduler.config)
92
+ new_config["clip_sample"] = False
93
+ scheduler._internal_dict = FrozenDict(new_config)
94
+
95
+ if safety_checker is None and requires_safety_checker:
96
+ logger.warning(
97
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
98
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
99
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
100
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
101
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
102
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
103
+ )
104
+
105
+ if safety_checker is not None and feature_extractor is None:
106
+ raise ValueError(
107
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
108
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
109
+ )
110
+
111
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
112
+ version.parse(unet.config._diffusers_version).base_version
113
+ ) < version.parse("0.9.0.dev0")
114
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
115
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
116
+ deprecation_message = (
117
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
118
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
119
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
120
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
121
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
122
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
123
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
124
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
125
+ " the `unet/config.json` file"
126
+ )
127
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
128
+ new_config = dict(unet.config)
129
+ new_config["sample_size"] = 64
130
+ unet._internal_dict = FrozenDict(new_config)
131
+
132
+ self.register_modules(
133
+ vae=vae,
134
+ text_encoder=text_encoder,
135
+ tokenizer=tokenizer,
136
+ unet=unet,
137
+ scheduler=scheduler,
138
+ safety_checker=safety_checker,
139
+ feature_extractor=feature_extractor,
140
+ )
141
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
142
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
143
+
144
+ # Borrow helper implementations from diffusers' StableDiffusionPipeline for convenience.
145
+ enable_vae_slicing = DiffusersStableDiffusionPipeline.enable_vae_slicing
146
+ disable_vae_slicing = DiffusersStableDiffusionPipeline.disable_vae_slicing
147
+ enable_sequential_cpu_offload = DiffusersStableDiffusionPipeline.enable_sequential_cpu_offload
148
+ _execution_device = DiffusersStableDiffusionPipeline._execution_device
149
+ _encode_prompt = DiffusersStableDiffusionPipeline._encode_prompt
150
+ run_safety_checker = DiffusersStableDiffusionPipeline.run_safety_checker
151
+ decode_latents = DiffusersStableDiffusionPipeline.decode_latents
152
+ prepare_extra_step_kwargs = DiffusersStableDiffusionPipeline.prepare_extra_step_kwargs
153
+ check_inputs = DiffusersStableDiffusionPipeline.check_inputs
154
+ prepare_latents = DiffusersStableDiffusionPipeline.prepare_latents
155
+
156
+ def prepare_metadata(
157
+ self, batch_size, metadata, do_classifier_free_guidance, device, dtype,
158
+ ):
159
+ has_metadata = getattr(self.unet.config, "use_metadata", False)
160
+ num_metadata = getattr(self.unet.config, "num_metadata", 0)
161
+
162
+ if metadata is None and has_metadata and num_metadata > 0:
163
+ metadata = torch.zeros((batch_size, num_metadata), device=device, dtype=dtype)
164
+
165
+ if metadata is None:
166
+ return None
167
+
168
+ md = torch.tensor(metadata) if not torch.is_tensor(metadata) else metadata
169
+ if len(md.shape) == 1:
170
+ md = md.unsqueeze(0).expand(batch_size, -1)
171
+ md = md.to(device=device, dtype=dtype)
172
+
173
+ if do_classifier_free_guidance:
174
+ md = torch.cat([torch.zeros_like(md), md])
175
+
176
+ return md
177
+
178
+ @torch.no_grad()
179
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
180
+ def __call__(
181
+ self,
182
+ prompt: Union[str, List[str]] = None,
183
+ height: Optional[int] = None,
184
+ width: Optional[int] = None,
185
+ num_inference_steps: int = 50,
186
+ guidance_scale: float = 7.5,
187
+ negative_prompt: Optional[Union[str, List[str]]] = None,
188
+ num_images_per_prompt: Optional[int] = 1,
189
+ eta: float = 0.0,
190
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
191
+ latents: Optional[torch.FloatTensor] = None,
192
+ prompt_embeds: Optional[torch.FloatTensor] = None,
193
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
194
+ output_type: Optional[str] = "pil",
195
+ return_dict: bool = True,
196
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
197
+ callback_steps: Optional[int] = 1,
198
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
199
+ metadata: Optional[List[float]] = None,
200
+ ):
201
+ # 0. Default height and width to unet
202
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
203
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
204
+
205
+ # 1. Check inputs. Raise error if not correct
206
+ self.check_inputs(
207
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
208
+ )
209
+
210
+ # 2. Define call parameters
211
+ if prompt is not None and isinstance(prompt, str):
212
+ batch_size = 1
213
+ elif prompt is not None and isinstance(prompt, list):
214
+ batch_size = len(prompt)
215
+ else:
216
+ batch_size = prompt_embeds.shape[0]
217
+
218
+ device = self._execution_device
219
+ do_classifier_free_guidance = guidance_scale > 1.0
220
+
221
+ # 3. Encode input prompt
222
+ prompt_embeds = self._encode_prompt(
223
+ prompt,
224
+ device,
225
+ num_images_per_prompt,
226
+ do_classifier_free_guidance,
227
+ negative_prompt,
228
+ prompt_embeds=prompt_embeds,
229
+ negative_prompt_embeds=negative_prompt_embeds,
230
+ )
231
+
232
+ # 4. Prepare timesteps
233
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
234
+ timesteps = self.scheduler.timesteps
235
+
236
+ # 5. Prepare latent variables
237
+ num_channels_latents = self.unet.in_channels if hasattr(self.unet, "in_channels") else self.unet.config.in_channels
238
+ latents = self.prepare_latents(
239
+ batch_size * num_images_per_prompt,
240
+ num_channels_latents,
241
+ height,
242
+ width,
243
+ prompt_embeds.dtype,
244
+ device,
245
+ generator,
246
+ latents,
247
+ )
248
+
249
+ # 6. Prepare extra step kwargs.
250
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
251
+
252
+ # 6.5: Prepare metadata (auto-zero filled when missing)
253
+ input_metadata = self.prepare_metadata(
254
+ batch_size, metadata, do_classifier_free_guidance, device, prompt_embeds.dtype
255
+ )
256
+ if input_metadata is not None:
257
+ assert input_metadata.shape[-1] == getattr(self.unet.config, "num_metadata", input_metadata.shape[-1])
258
+ assert input_metadata.shape[0] == prompt_embeds.shape[0]
259
+
260
+ # 7. Denoising loop
261
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
262
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
263
+ for i, t in enumerate(timesteps):
264
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
265
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
266
+
267
+ noise_pred = self.unet(
268
+ latent_model_input,
269
+ t,
270
+ metadata=input_metadata,
271
+ encoder_hidden_states=prompt_embeds,
272
+ cross_attention_kwargs=cross_attention_kwargs,
273
+ ).sample
274
+
275
+ if do_classifier_free_guidance:
276
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
277
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
278
+
279
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
280
+
281
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
282
+ progress_bar.update()
283
+ if callback is not None and i % callback_steps == 0:
284
+ callback(i, t, latents)
285
+
286
+ if output_type == "latent":
287
+ image = latents
288
+ has_nsfw_concept = None
289
+ elif output_type == "pil":
290
+ image = self.decode_latents(latents)
291
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
292
+ image = self.numpy_to_pil(image)
293
+ else:
294
+ image = self.decode_latents(latents)
295
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
296
+
297
+ if not return_dict:
298
+ return (image, has_nsfw_concept)
299
+
300
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
301
+
302
+
303
+ __all__ = ["DiffusionSatPipeline"]