TorchRik commited on
Commit
f1c2d7c
·
verified ·
1 Parent(s): 3ed79e7

Upload combined_stable_diffusion.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. combined_stable_diffusion.py +352 -0
combined_stable_diffusion.py ADDED
@@ -0,0 +1,352 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import random
3
+
4
+ import torch
5
+ from diffusers import DiffusionPipeline
6
+ from diffusers.image_processor import VaeImageProcessor
7
+ from huggingface_hub import PyTorchModelHubMixin
8
+ from PIL import Image
9
+
10
+
11
class CombinedStableDiffusion(
    DiffusionPipeline,
    PyTorchModelHubMixin
):
    """
    A Stable Diffusion model wrapper that provides functionality for text-to-image synthesis,
    noise scheduling, latent space manipulation, and image decoding.

    Two UNets are registered: ``original_unet`` and ``fine_tuned_unet``.
    ``_get_unet_prediction`` selects between them via the
    ``self._use_original_unet`` flag (set by ``__call__``).
    """
    def __init__(
        self,
        original_unet: torch.nn.Module,
        fine_tuned_unet: torch.nn.Module,
        scheduler,
        vae: torch.nn.Module,
        tokenizer=None,
        text_encoder=None,
    ) -> None:
        """Register the pipeline components and derive the VAE scale factor.

        Args:
            original_unet (torch.nn.Module): Base (pretrained) UNet.
            fine_tuned_unet (torch.nn.Module): Fine-tuned UNet variant.
            scheduler: Noise scheduler (diffusers scheduler API: ``timesteps``,
                ``scale_model_input``, ``step``, ``set_timesteps``).
            vae (torch.nn.Module): Variational autoencoder with a
                ``config.block_out_channels`` list and ``config.scaling_factor``.
            tokenizer: Optional text tokenizer (required for generation).
            text_encoder: Optional text encoder (required for generation).
        """

        super().__init__()

        # register_modules (from DiffusionPipeline) exposes each component
        # as an attribute on self, e.g. self.vae, self.scheduler.
        self.register_modules(
            tokenizer=tokenizer,
            original_unet=original_unet,
            fine_tuned_unet=fine_tuned_unet,
            scheduler=scheduler,
            vae=vae,
            text_encoder=text_encoder,
        )

        # Spatial downsampling factor implied by the number of VAE blocks
        # (standard diffusers convention: 2 ** (num_blocks - 1)).
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor
        )
44
+
45
+ def _get_negative_prompts(self, batch_size: int) -> torch.Tensor:
46
+ return self.tokenizer(
47
+ [""] * batch_size,
48
+ max_length=self.tokenizer.model_max_length,
49
+ padding="max_length",
50
+ truncation=True,
51
+ return_tensors="pt",
52
+ ).input_ids
53
+
54
+ def _get_encoder_hidden_states(
55
+ self, tokenized_prompts: torch.Tensor, do_classifier_free_guidance: bool = False
56
+ ) -> torch.Tensor:
57
+ if do_classifier_free_guidance:
58
+ tokenized_prompts = torch.cat(
59
+ [
60
+ self._get_negative_prompts(tokenized_prompts.shape[0]).to(
61
+ tokenized_prompts.device
62
+ ),
63
+ tokenized_prompts,
64
+ ]
65
+ )
66
+
67
+ return self.text_encoder(tokenized_prompts)[0]
68
+
69
+ def _get_unet_prediction(
70
+ self,
71
+ latent_model_input: torch.Tensor,
72
+ timestep: int,
73
+ encoder_hidden_states: torch.Tensor,
74
+ ) -> torch.Tensor:
75
+ """
76
+ Return unet noise prediction
77
+
78
+ Args:
79
+ latent_model_input (torch.Tensor): Unet latents input
80
+ timestep (int): noise scheduler timestep
81
+ encoder_hidden_states (torch.Tensor): Text encoder hidden states
82
+
83
+ Returns:
84
+ torch.Tensor: noise prediction
85
+ """
86
+ unet = self.original_unet if self._use_original_unet else self.fine_tuned_unet
87
+
88
+ return unet(
89
+ latent_model_input,
90
+ timestep=timestep,
91
+ encoder_hidden_states=encoder_hidden_states,
92
+ ).sample
93
+
94
+ def get_noise_prediction(
95
+ self,
96
+ latents: torch.Tensor,
97
+ timestep_index: int,
98
+ encoder_hidden_states: torch.Tensor,
99
+ do_classifier_free_guidance: bool = False,
100
+ detach_main_path: bool = False,
101
+ ):
102
+ """
103
+ Return noise prediction
104
+
105
+ Args:
106
+ latents (torch.Tensor): Image latents
107
+ timestep_index (int): noise scheduler timestep index
108
+ encoder_hidden_states (torch.Tensor): Text encoder hidden states
109
+ do_classifier_free_guidance (bool) Whether to do classifier free guidance
110
+ detach_main_path (bool): Detach gradient
111
+
112
+ Returns:
113
+ torch.Tensor: noise prediction
114
+ """
115
+ timestep = self.scheduler.timesteps[timestep_index]
116
+
117
+ latent_model_input = self.scheduler.scale_model_input(
118
+ sample=torch.cat([latents] * 2) if do_classifier_free_guidance else latents,
119
+ timestep=timestep,
120
+ )
121
+
122
+ noise_pred = self._get_unet_prediction(
123
+ latent_model_input=latent_model_input,
124
+ timestep=timestep,
125
+ encoder_hidden_states=encoder_hidden_states,
126
+ )
127
+
128
+ if do_classifier_free_guidance:
129
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
130
+ if detach_main_path:
131
+ noise_pred_text = noise_pred_text.detach()
132
+
133
+ noise_pred = noise_pred_uncond + self.guidance_scale * (
134
+ noise_pred_text - noise_pred_uncond
135
+ )
136
+ return noise_pred
137
+
138
+ def sample_next_latents(
139
+ self,
140
+ latents: torch.Tensor,
141
+ timestep_index: int,
142
+ noise_pred: torch.Tensor,
143
+ return_pred_original: bool = False,
144
+ ) -> torch.Tensor:
145
+ """
146
+ Return next latents prediction
147
+
148
+ Args:
149
+ latents (torch.Tensor): Image latents
150
+ timestep_index (int): noise scheduler timestep index
151
+ noise_pred (torch.Tensor): noise prediction
152
+ return_pred_original (bool) Whether to sample original sample
153
+
154
+ Returns:
155
+ torch.Tensor: latent prediction
156
+ """
157
+ timestep = self.scheduler.timesteps[timestep_index]
158
+ sample = self.scheduler.step(
159
+ model_output=noise_pred, timestep=timestep, sample=latents
160
+ )
161
+ return (
162
+ sample.pred_original_sample if return_pred_original else sample.prev_sample
163
+ )
164
+
165
+ def predict_next_latents(
166
+ self,
167
+ latents: torch.Tensor,
168
+ timestep_index: int,
169
+ encoder_hidden_states: torch.Tensor,
170
+ return_pred_original: bool = False,
171
+ do_classifier_free_guidance: bool = False,
172
+ detach_main_path: bool = False,
173
+ ) -> tuple[torch.Tensor, torch.Tensor]:
174
+ """
175
+ Predicts the next latent states during the diffusion process.
176
+
177
+ Args:
178
+ latents (torch.Tensor): Current latent states.
179
+ timestep_index (int): Index of the current timestep.
180
+ encoder_hidden_states (torch.Tensor): Encoder hidden states from the text encoder.
181
+ return_pred_original (bool): Whether to return the predicted original sample.
182
+ do_classifier_free_guidance (bool) Whether to do classifier free guidance
183
+ detach_main_path (bool): Detach gradient
184
+
185
+ Returns:
186
+ tuple: Next latents and predicted noise tensor.
187
+ """
188
+
189
+ noise_pred = self.get_noise_prediction(
190
+ latents=latents,
191
+ timestep_index=timestep_index,
192
+ encoder_hidden_states=encoder_hidden_states,
193
+ do_classifier_free_guidance=do_classifier_free_guidance,
194
+ detach_main_path=detach_main_path,
195
+ )
196
+
197
+ latents = self.sample_next_latents(
198
+ latents=latents,
199
+ noise_pred=noise_pred,
200
+ timestep_index=timestep_index,
201
+ return_pred_original=return_pred_original,
202
+ )
203
+
204
+ return latents, noise_pred
205
+
206
+ def get_latents(self, batch_size: int, device: torch.device) -> torch.Tensor:
207
+ latent_resolution = int(self.resolution) // self.vae_scale_factor
208
+ return torch.randn(
209
+ (
210
+ batch_size,
211
+ self.original_unet.config.in_channels,
212
+ latent_resolution,
213
+ latent_resolution,
214
+ ),
215
+ device=device,
216
+ )
217
+
218
+ def do_k_diffusion_steps(
219
+ self,
220
+ start_timestep_index: int,
221
+ end_timestep_index: int,
222
+ latents: torch.Tensor,
223
+ encoder_hidden_states: torch.Tensor,
224
+ return_pred_original: bool = False,
225
+ do_classifier_free_guidance: bool = False,
226
+ detach_main_path: bool = False,
227
+ ) -> tuple[torch.Tensor, torch.Tensor]:
228
+ """
229
+ Performs multiple diffusion steps between specified timesteps.
230
+
231
+ Args:
232
+ start_timestep_index (int): Starting timestep index.
233
+ end_timestep_index (int): Ending timestep index.
234
+ latents (torch.Tensor): Initial latents.
235
+ encoder_hidden_states (torch.Tensor): Encoder hidden states.
236
+ return_pred_original (bool): Whether to return the predicted original sample.
237
+ do_classifier_free_guidance (bool) Whether to do classifier free guidance
238
+ detach_main_path (bool): Detach gradient
239
+
240
+ Returns:
241
+ tuple: Resulting latents and encoder hidden states.
242
+ """
243
+ assert start_timestep_index <= end_timestep_index
244
+
245
+ for timestep_index in range(start_timestep_index, end_timestep_index - 1):
246
+ latents, _ = self.predict_next_latents(
247
+ latents=latents,
248
+ timestep_index=timestep_index,
249
+ encoder_hidden_states=encoder_hidden_states,
250
+ return_pred_original=False,
251
+ do_classifier_free_guidance=do_classifier_free_guidance,
252
+ detach_main_path=detach_main_path,
253
+ )
254
+ res, _ = self.predict_next_latents(
255
+ latents=latents,
256
+ timestep_index=end_timestep_index - 1,
257
+ encoder_hidden_states=encoder_hidden_states,
258
+ return_pred_original=return_pred_original,
259
+ do_classifier_free_guidance=do_classifier_free_guidance,
260
+ )
261
+ return res, encoder_hidden_states
262
+
263
+ def get_pil_image(self, raw_images: torch.Tensor) -> list[Image]:
264
+ do_denormalize = [True] * raw_images.shape[0]
265
+ images = self.inference_image_processor.postprocess(
266
+ raw_images, output_type="pil", do_denormalize=do_denormalize
267
+ )
268
+ return images
269
+
270
    def get_reward_image(self, raw_images: torch.Tensor) -> torch.Tensor:
        """Map raw decoded images into the input space of a reward model.

        Args:
            raw_images (torch.Tensor): Decoded images, assumed to lie in
                [-1, 1] (standard VAE output range — TODO confirm).

        Returns:
            torch.Tensor: Images processed by ``self.reward_image_processor``.
        """
        # Denormalize from [-1, 1] to [0, 1] and clamp.
        reward_images = (raw_images / 2 + 0.5).clamp(0, 1)

        if self.use_image_shifting:
            # NOTE(review): `_shift_tensor_batch` is not defined anywhere in
            # this file, and its return value is discarded here — if it shifts
            # out-of-place, the shift silently has no effect. Presumably it
            # mutates `reward_images` in place; confirm against its definition.
            # Also note `self.use_image_shifting` and `self.resolution` are
            # never assigned in this class — callers must set them.
            self._shift_tensor_batch(
                reward_images,
                dx=random.randint(0, math.ceil(self.resolution / 224)),
                dy=random.randint(0, math.ceil(self.resolution / 224)),
            )

        return self.reward_image_processor(reward_images)
281
+
282
+ @torch.no_grad()
283
+ def __call__(
284
+ self,
285
+ prompt: str | list[str],
286
+ num_inference_steps=40,
287
+ original_unet_steps=30,
288
+ resolution=512,
289
+ guidance_scale=7.5,
290
+ generator=None
291
+ ):
292
+ self.guidance_scale = guidance_scale
293
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
294
+
295
+ tokenized_prompts = self.tokenizer(
296
+ prompt,
297
+ return_tensors="pt",
298
+ padding="max_length",
299
+ max_length=self.tokenizer.model_max_length,
300
+ truncation=True
301
+ ).input_ids.to(self.device)
302
+ original_encoder_hidden_states = self._get_encoder_hidden_states(
303
+ tokenized_prompts=tokenized_prompts,
304
+ do_classifier_free_guidance=True
305
+ )
306
+ fine_tuned_encoder_hidden_states = self._get_encoder_hidden_states(
307
+ tokenized_prompts=tokenized_prompts,
308
+ do_classifier_free_guidance=False
309
+ )
310
+
311
+ latent_resolution = int(resolution) // self.vae_scale_factor
312
+ latents = torch.randn(
313
+ (
314
+ batch_size,
315
+ self.original_unet.config.in_channels,
316
+ latent_resolution,
317
+ latent_resolution,
318
+ ),
319
+ device=self.device,
320
+ )
321
+
322
+ self.scheduler.set_timesteps(
323
+ num_inference_steps,
324
+ device=self.device
325
+ )
326
+
327
+ self._use_original_unet = True
328
+ latents, _ = self.do_k_diffusion_steps(
329
+ start_timestep_index=0,
330
+ end_timestep_index=original_unet_steps,
331
+ latents=latents,
332
+ encoder_hidden_states=original_encoder_hidden_states,
333
+ return_pred_original=False,
334
+ do_classifier_free_guidance=True,
335
+ )
336
+
337
+ self._use_original_unet = False
338
+ pred_original_sample, _ = self.do_k_diffusion_steps(
339
+ start_timestep_index=original_unet_steps,
340
+ end_timestep_index=num_inference_steps,
341
+ latents=latents,
342
+ encoder_hidden_states=fine_tuned_encoder_hidden_states,
343
+ return_pred_original=False,
344
+ do_classifier_free_guidance=False,
345
+ )
346
+
347
+ pred_original_sample /= self.vae.config.scaling_factor
348
+
349
+ image = self.vae.decode(pred_original_sample).sample
350
+ return self.image_processor.postprocess(
351
+ image, output_type='pil'
352
+ )