BiliSakura commited on
Commit
d4d085b
·
verified ·
1 Parent(s): 596f048

Update all files for DiffusionSat-Single-512

Browse files
Files changed (1) hide show
  1. pipeline_diffusionsat.py +307 -0
pipeline_diffusionsat.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Self-contained DiffusionSat text-to-image pipeline that can be loaded directly
3
+ from the checkpoint folder without importing the project package.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ from typing import Any, Callable, Dict, List, Optional, Union
9
+
10
+ import torch
11
+ from packaging import version
12
+ from transformers import CLIPTextModel, CLIPTokenizer
13
+ try:
14
+ from transformers import CLIPImageProcessor
15
+ except ImportError:
16
+ from transformers import CLIPFeatureExtractor as CLIPImageProcessor
17
+
18
+ from diffusers.configuration_utils import FrozenDict
19
+ from diffusers.models import AutoencoderKL
20
+ from diffusers.schedulers import KarrasDiffusionSchedulers
21
+ from diffusers.utils import (
22
+ deprecate,
23
+ logging,
24
+ randn_tensor,
25
+ replace_example_docstring,
26
+ is_accelerate_available,
27
+ )
28
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
29
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
30
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
31
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
32
+ StableDiffusionPipeline as DiffusersStableDiffusionPipeline,
33
+ )
34
+
35
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
36
+
37
# Usage snippet spliced into the pipeline's __call__ docstring via the
# @replace_example_docstring decorator (diffusers convention).
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import DiffusionPipeline

        >>> pipe = DiffusionPipeline.from_pretrained("path/to/ckpt/diffusionsat", torch_dtype=torch.float16)
        >>> pipe = pipe.to("cuda")

        >>> prompt = "a photo of an astronaut riding a horse on mars"
        >>> image = pipe(prompt).images[0]
        ```
"""
50
+
51
+
52
class DiffusionSatPipeline(DiffusionPipeline):
    """
    Pipeline for text-to-image generation using the DiffusionSat UNet with optional
    numeric metadata conditioning.

    Mirrors diffusers' ``StableDiffusionPipeline`` but additionally forwards a
    ``metadata`` tensor to the UNet at every denoising step (zero-filled when the
    UNet expects metadata and none is supplied).

    Args:
        vae: Variational auto-encoder used to decode latents into images.
        text_encoder: Frozen CLIP text encoder.
        tokenizer: Tokenizer paired with ``text_encoder``.
        unet: DiffusionSat UNet. May expose ``config.use_metadata`` and
            ``config.num_metadata`` and accept a ``metadata`` keyword in its forward.
        scheduler: Scheduler used to denoise the encoded latents.
        safety_checker: Optional NSFW classifier; may be ``None``.
        feature_extractor: Image processor feeding the safety checker.
        requires_safety_checker: If ``True``, warn when ``safety_checker`` is ``None``.
    """

    _optional_components = ["safety_checker", "feature_extractor"]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: Any,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
    ):
        super().__init__()

        # Legacy-config shim (copied from diffusers): old scheduler configs
        # shipped with steps_offset != 1; silently patch and warn.
        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        # Legacy-config shim: `clip_sample` must be False for SD-style latents.
        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
            )
            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["clip_sample"] = False
            scheduler._internal_dict = FrozenDict(new_config)

        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        if safety_checker is not None and feature_extractor is None:
            # FIX: the original string lacked the f-prefix, so "{self.__class__}"
            # appeared literally in the error message.
            raise ValueError(
                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        # Legacy-config shim: pre-0.9.0 checkpoints stored sample_size=32 for
        # 512px models; bump to 64 with a deprecation notice.
        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
            version.parse(unet.config._diffusers_version).base_version
        ) < version.parse("0.9.0.dev0")
        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
            deprecation_message = (
                "The configuration file of the unet has set the default `sample_size` to smaller than"
                " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
                " in the config might lead to incorrect results in future versions. If you have downloaded this"
                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
                " the `unet/config.json` file"
            )
            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(unet.config)
            new_config["sample_size"] = 64
            unet._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        # 2**(n_blocks - 1): each VAE down-block halves the spatial resolution.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    # Borrow helper implementations from diffusers' StableDiffusionPipeline for convenience.
    enable_vae_slicing = DiffusersStableDiffusionPipeline.enable_vae_slicing
    disable_vae_slicing = DiffusersStableDiffusionPipeline.disable_vae_slicing
    enable_sequential_cpu_offload = DiffusersStableDiffusionPipeline.enable_sequential_cpu_offload
    _execution_device = DiffusersStableDiffusionPipeline._execution_device
    _encode_prompt = DiffusersStableDiffusionPipeline._encode_prompt
    run_safety_checker = DiffusersStableDiffusionPipeline.run_safety_checker
    decode_latents = DiffusersStableDiffusionPipeline.decode_latents
    prepare_extra_step_kwargs = DiffusersStableDiffusionPipeline.prepare_extra_step_kwargs
    check_inputs = DiffusersStableDiffusionPipeline.check_inputs
    prepare_latents = DiffusersStableDiffusionPipeline.prepare_latents

    def prepare_metadata(
        self, batch_size, metadata, do_classifier_free_guidance, device, dtype,
    ):
        """
        Build the metadata tensor fed to the UNet.

        Args:
            batch_size: Effective batch size, i.e. prompts * num_images_per_prompt.
            metadata: Per-sample metadata values; a 1-D sequence (broadcast to the
                batch), a 2-D tensor/sequence, or ``None``.
            do_classifier_free_guidance: When ``True``, the unconditional half of
                the batch gets zeroed metadata, matching the prompt-embed layout.
            device: Target device.
            dtype: Target dtype (matched to the prompt embeddings).

        Returns:
            A ``(batch, num_metadata)`` tensor — doubled along dim 0 under CFG —
            or ``None`` when the UNet takes no metadata.
        """
        has_metadata = getattr(self.unet.config, "use_metadata", False)
        num_metadata = getattr(self.unet.config, "num_metadata", 0)

        # Auto-zero-fill when the UNet expects metadata but none was supplied.
        if metadata is None and has_metadata and num_metadata > 0:
            metadata = torch.zeros((batch_size, num_metadata), device=device, dtype=dtype)

        if metadata is None:
            return None

        md = metadata if torch.is_tensor(metadata) else torch.tensor(metadata)
        if md.ndim == 1:
            md = md.unsqueeze(0)
        # FIX: the original only broadcast 1-D input to `batch_size` rows, so the
        # batch dimension went stale when num_images_per_prompt > 1. Broadcast any
        # divisor-sized batch up to the effective batch size.
        if md.shape[0] != batch_size:
            if batch_size % md.shape[0] != 0:
                raise ValueError(
                    f"metadata batch {md.shape[0]} cannot be broadcast to effective batch size {batch_size}"
                )
            md = md.repeat_interleave(batch_size // md.shape[0], dim=0)
        md = md.to(device=device, dtype=dtype)

        if do_classifier_free_guidance:
            md = torch.cat([torch.zeros_like(md), md])

        return md

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        metadata: Optional[List[float]] = None,
    ):
        r"""
        Generate images from a text prompt, optionally conditioned on metadata.

        FIX: the original `__call__` had no docstring at all, which makes
        `@replace_example_docstring` raise at import time (it splits
        ``fn.__doc__``, and ``None`` has no ``split``). The decorator also
        requires an "Examples:" line, provided below.

        Args:
            prompt: Text prompt(s); ignored when `prompt_embeds` is given.
            height / width: Output size in pixels; default is the UNet sample
                size times the VAE scale factor.
            num_inference_steps: Number of denoising steps.
            guidance_scale: CFG scale; values > 1 enable guidance.
            negative_prompt: Prompt(s) to steer away from under CFG.
            num_images_per_prompt: Images generated per prompt.
            eta: DDIM eta, forwarded to the scheduler when supported.
            generator: RNG(s) for deterministic sampling.
            latents: Optional pre-sampled initial latents.
            prompt_embeds / negative_prompt_embeds: Precomputed embeddings.
            output_type: "pil", "latent", or anything else for numpy output.
            return_dict: When False, return a plain (images, nsfw_flags) tuple.
            callback / callback_steps: Per-step progress callback and stride.
            cross_attention_kwargs: Forwarded to the UNet attention processors.
            metadata: Optional per-sample metadata values for the UNet.

        Examples:

        Returns:
            `StableDiffusionPipelineOutput` or, when `return_dict=False`, a
            tuple `(images, nsfw_content_detected)`.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.in_channels if hasattr(self.unet, "in_channels") else self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 6.5: Prepare metadata (auto-zero filled when missing).
        # FIX: use the effective batch size so metadata lines up with
        # prompt_embeds when num_images_per_prompt > 1.
        input_metadata = self.prepare_metadata(
            batch_size * num_images_per_prompt, metadata, do_classifier_free_guidance, device, prompt_embeds.dtype
        )
        if input_metadata is not None:
            # Explicit errors instead of `assert` (asserts vanish under -O).
            expected_num = getattr(self.unet.config, "num_metadata", input_metadata.shape[-1])
            if input_metadata.shape[-1] != expected_num:
                raise ValueError(
                    f"metadata has {input_metadata.shape[-1]} values per sample but the UNet expects {expected_num}"
                )
            if input_metadata.shape[0] != prompt_embeds.shape[0]:
                raise ValueError(
                    f"metadata batch {input_metadata.shape[0]} does not match prompt batch {prompt_embeds.shape[0]}"
                )

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # Duplicate latents for the unconditional/conditional halves under CFG.
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    metadata=input_metadata,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # Update the bar once per effective step (scheduler.order sub-steps).
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        if output_type == "latent":
            image = latents
            has_nsfw_concept = None
        elif output_type == "pil":
            image = self.decode_latents(latents)
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
            image = self.numpy_to_pil(image)
        else:
            image = self.decode_latents(latents)
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
305
+
306
+
307
# Explicit public API of this module.
__all__ = ["DiffusionSatPipeline"]