suvadityamuk committed
Commit 63334c4 · 1 Parent(s): f2c6e18

Upload stable_diffusion_comparison.py

Files changed (1)
  1. stable_diffusion_comparison.py +365 -0
stable_diffusion_comparison.py ADDED
@@ -0,0 +1,365 @@
from typing import Any, Callable, Dict, List, Optional, Union

import torch

from diffusers import DiffusionPipeline, StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput


pipe1_model_id = "CompVis/stable-diffusion-v1-1"
pipe2_model_id = "CompVis/stable-diffusion-v1-2"
pipe3_model_id = "CompVis/stable-diffusion-v1-3"
pipe4_model_id = "CompVis/stable-diffusion-v1-4"

class StableDiffusionComparisonPipeline(DiffusionPipeline):
    r"""
    Pipeline for a side-by-side comparison of Stable Diffusion v1.1-v1.4.

    This pipeline inherits from `DiffusionPipeline` and requires an auth token for
    downloading the pre-trained checkpoints from the Hugging Face Hub.

    Args:
        sd1_1 (`StableDiffusionPipeline` or `str`):
            A Stable Diffusion pipeline prepared from the SD v1.1 checkpoint on the Hugging Face Hub.
        sd1_2 (`StableDiffusionPipeline` or `str`):
            A Stable Diffusion pipeline prepared from the SD v1.2 checkpoint on the Hugging Face Hub.
        sd1_3 (`StableDiffusionPipeline` or `str`):
            A Stable Diffusion pipeline prepared from the SD v1.3 checkpoint on the Hugging Face Hub.
        sd1_4 (`StableDiffusionPipeline` or `str`):
            A Stable Diffusion pipeline prepared from the SD v1.4 checkpoint on the Hugging Face Hub.
    """

    def __init__(
        self,
        sd1_1: Union[StableDiffusionPipeline, str],
        sd1_2: Union[StableDiffusionPipeline, str],
        sd1_3: Union[StableDiffusionPipeline, str],
        sd1_4: Union[StableDiffusionPipeline, str],
    ):
        super().__init__()

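        # Each argument may be a ready StableDiffusionPipeline; anything else
        # (e.g. a model id string) falls back to downloading the corresponding
        # fp16 checkpoint from the Hub, which requires an auth token.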
        if not isinstance(sd1_1, StableDiffusionPipeline):
            self.pipe1 = StableDiffusionPipeline.from_pretrained(
                pipe1_model_id, torch_dtype=torch.float16, revision="fp16", use_auth_token=True
            )
        else:
            self.pipe1 = sd1_1
        if not isinstance(sd1_2, StableDiffusionPipeline):
            self.pipe2 = StableDiffusionPipeline.from_pretrained(
                pipe2_model_id, torch_dtype=torch.float16, revision="fp16", use_auth_token=True
            )
        else:
            self.pipe2 = sd1_2
        if not isinstance(sd1_3, StableDiffusionPipeline):
            self.pipe3 = StableDiffusionPipeline.from_pretrained(
                pipe3_model_id, torch_dtype=torch.float16, revision="fp16", use_auth_token=True
            )
        else:
            self.pipe3 = sd1_3
        if not isinstance(sd1_4, StableDiffusionPipeline):
            self.pipe4 = StableDiffusionPipeline.from_pretrained(
                pipe4_model_id, torch_dtype=torch.float16, revision="fp16", use_auth_token=True
            )
        else:
            self.pipe4 = sd1_4

        self.register_modules(pipeline1=self.pipe1, pipeline2=self.pipe2, pipeline3=self.pipe3, pipeline4=self.pipe4)

    @property
    def layers(self) -> Dict[str, Any]:
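        """Return the registered sub-pipelines, keyed by their registered module names."""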
        return {k: getattr(self, k) for k in self.config.keys() if not k.startswith("_")}

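    # The four text2img_sd1_x helpers below simply forward their arguments to the
    # corresponding sub-pipeline, so each checkpoint can also be queried on its own.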
    @torch.no_grad()
    def text2img_sd1_1(
        self,
        prompt: Union[str, List[str]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        **kwargs,
    ):
        return self.pipe1(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            **kwargs,
        )

    @torch.no_grad()
    def text2img_sd1_2(
        self,
        prompt: Union[str, List[str]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        **kwargs,
    ):
        return self.pipe2(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            **kwargs,
        )

    @torch.no_grad()
    def text2img_sd1_3(
        self,
        prompt: Union[str, List[str]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        **kwargs,
    ):
        return self.pipe3(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            **kwargs,
        )

    @torch.no_grad()
    def text2img_sd1_4(
        self,
        prompt: Union[str, List[str]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        **kwargs,
    ):
        return self.pipe4(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            **kwargs,
        )

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        **kwargs,
    ):
        r"""
        Function invoked when calling the pipeline for generation. It generates four results, one from each of the
        SD v1.1-v1.4 checkpoints, by running the four pipelines sequentially on the same inputs.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the [DDIM paper](https://arxiv.org/abs/2010.02502). Only applies
                to [`schedulers.DDIMScheduler`] and is ignored for other schedulers.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a
                latents tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
                `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
                element is a list of `bool`s denoting whether the corresponding generated image likely represents
                "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
        """

        # Run on GPU when one is available
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.to(device)

        # The VAE downsamples by a factor of 8, so height and width must be divisible by 8
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` must be divisible by 8 but are {height} and {width}.")

        # Get the result from Stable Diffusion checkpoint v1.1
        res1 = self.text2img_sd1_1(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            **kwargs,
        )

        # Get the result from Stable Diffusion checkpoint v1.2
        res2 = self.text2img_sd1_2(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            **kwargs,
        )

        # Get the result from Stable Diffusion checkpoint v1.3
        res3 = self.text2img_sd1_3(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            **kwargs,
        )

        # Get the result from Stable Diffusion checkpoint v1.4
        res4 = self.text2img_sd1_4(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            **kwargs,
        )
363
+
364
+ # Get all result images into a single list and pass it via StableDiffusionPipelineOutput for final result
365
+ return StableDiffusionPipelineOutput([res1[0], res2[0], res3[0], res4[0]])
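

# A minimal usage sketch (illustrative; the prompt and file names below are assumptions,
# not part of the pipeline API). Passing the model ids as plain strings makes __init__
# download each checkpoint itself, which requires being logged in to the Hub. Also
# assumes a CUDA device and enough memory to hold all four fp16 pipelines at once.
if __name__ == "__main__":
    pipe = StableDiffusionComparisonPipeline(
        sd1_1=pipe1_model_id,
        sd1_2=pipe2_model_id,
        sd1_3=pipe3_model_id,
        sd1_4=pipe4_model_id,
    )
    result = pipe("a photograph of an astronaut riding a horse", num_inference_steps=25)
    # One image per checkpoint with the default num_images_per_prompt=1
    for version, image in zip(("1-1", "1-2", "1-3", "1-4"), result.images):
        image.save(f"astronaut_sd_v{version}.png")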