orrzohar committed
Commit 918f558 · verified · 1 parent: 9f0b1a2

Upload pipeline_llava_gen.py with huggingface_hub

Files changed (1): pipeline_llava_gen.py (+287 -0)
pipeline_llava_gen.py ADDED
# -*- coding: utf-8 -*-
# ===========================================================================================
#
# Copyright (c) Beijing Academy of Artificial Intelligence (BAAI). All rights reserved.
#
# Author : Fan Zhang
# Email : zhangfan@baai.ac.cn
# Institute : Beijing Academy of Artificial Intelligence (BAAI)
# Create On : 2023-12-19 10:45
# Last Modified : 2023-12-25 07:59
# File Name : pipeline_llava_gen.py
# Description :
#
# ===========================================================================================

from dataclasses import dataclass
from typing import List, Optional

from PIL import Image
import numpy as np
import torch
from torchvision import transforms as TF
from tqdm import tqdm

from diffusers import AutoencoderKL, DiffusionPipeline, EulerDiscreteScheduler, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.utils import BaseOutput
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer, CLIPImageProcessor

EVA_IMAGE_SIZE = 448
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
DEFAULT_IMG_PLACEHOLDER = "<image>"

image_processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct").image_processor
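# Note: loading the processor at import time shares one instance across pipeline
# objects. Qwen2.5-VL encodes images with 14-pixel patches merged 2x2, so a
# 448x448 input yields (448 / 28) ** 2 = 256 vision tokens, which is why
# _prepare_and_encode_inputs below inserts exactly 256 "<|image_pad|>" tokens per image.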


@dataclass
class EmuVisualGenerationPipelineOutput(BaseOutput):
    image: Image.Image
    nsfw_content_detected: Optional[bool]


class EmuVisualGenerationPipeline(DiffusionPipeline):

    def __init__(
        self,
        tokenizer: AutoTokenizer,
        multimodal_encoder: AutoModelForCausalLM,
        scheduler: EulerDiscreteScheduler,
        unet: UNet2DConditionModel,
        vae: AutoencoderKL,
        feature_extractor: CLIPImageProcessor,
        safety_checker: StableDiffusionSafetyChecker,
        eva_size=EVA_IMAGE_SIZE,
        eva_mean=OPENAI_DATASET_MEAN,
        eva_std=OPENAI_DATASET_STD,
    ):
        super().__init__()
        self.register_modules(
            tokenizer=tokenizer,
            multimodal_encoder=multimodal_encoder,
            scheduler=scheduler,
            unet=unet,
            vae=vae,
            feature_extractor=feature_extractor,
            safety_checker=None,  # the passed safety_checker is not registered: safety checking is disabled in this pipeline
        )

        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

        self.transform = TF.Compose([
            TF.Resize((eva_size, eva_size), interpolation=TF.InterpolationMode.BICUBIC),
            TF.ToTensor(),
            TF.Normalize(mean=eva_mean, std=eva_std),
        ])

        # Cache for negative (unconditional) prompt embeddings, keyed by prompt type.
        self.negative_prompt = {}

    def device(self, module):
        return next(module.parameters()).device

    def dtype(self, module):
        return next(module.parameters()).dtype

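    # Note: these helpers shadow DiffusionPipeline's device/dtype properties; they are
    # written as per-module lookups, presumably because the registered submodules may
    # sit on different devices or dtypes.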
    @torch.no_grad()
    def __call__(
        self,
        inputs: List[Image.Image | str] | str | Image.Image,
        height: int = 1024,
        width: int = 1024,
        num_inference_steps: int = 50,
        guidance_scale: float = 3.0,
        crop_info: List[int] = [0, 0],
        original_size: List[int] = [1024, 1024],
    ):
        if not isinstance(inputs, list):
            inputs = [inputs]

        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        device = self.device(self.unet)
        dtype = self.dtype(self.unet)

        do_classifier_free_guidance = guidance_scale > 1.0

        # 1. Encode input prompt
        prompt_embeds = self._prepare_and_encode_inputs(
            inputs,
            do_classifier_free_guidance,
        ).to(dtype).to(device)
        batch_size = prompt_embeds.shape[0] // 2 if do_classifier_free_guidance else prompt_embeds.shape[0]

        unet_added_conditions = {}
        time_ids = torch.LongTensor(original_size + crop_info + [height, width]).to(device)
        if do_classifier_free_guidance:
            unet_added_conditions["time_ids"] = torch.cat([time_ids, time_ids], dim=0)
        else:
            unet_added_conditions["time_ids"] = time_ids
        unet_added_conditions["text_embeds"] = torch.mean(prompt_embeds, dim=1)
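
        # These added conditions mirror SDXL-style micro-conditioning: "time_ids" packs
        # (original_size, crop_top_left, target_size) into six integers, and
        # "text_embeds" plays the role of SDXL's pooled text embedding, approximated
        # here by mean-pooling the prompt embeddings.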

        # 2. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 3. Prepare latent variables
        shape = (
            batch_size,
            self.unet.config.in_channels,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        latents = torch.randn(shape, device=device, dtype=dtype)
        latents = latents * self.scheduler.init_noise_sigma
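        # EulerDiscreteScheduler expects the initial noise scaled to the scheduler's
        # starting noise level; init_noise_sigma supplies that factor (it is 1.0 for
        # plain DDPM-style schedulers).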

        # 4. Denoising loop
        for t in tqdm(timesteps):
            # Expand the latents if doing classifier free guidance: 2B x 4 x H x W
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                added_cond_kwargs=unet_added_conditions,
            ).sample

            # Perform guidance
            if do_classifier_free_guidance:
                noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

            # Compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample
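
        # Note on the chunk order above: _prepare_and_encode_inputs concatenates the
        # conditional embeddings first and the negative embeddings second, so chunk(2)
        # yields (cond, uncond) and the update is the standard CFG rule
        # uncond + scale * (cond - uncond).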

        # 5. Post-processing
        images = self.decode_latents(latents)

        # 6. Safety checking is disabled (safety_checker is registered as None):
        # images, has_nsfw_concept = self.run_safety_checker(images)

        # 7. Convert to PIL
        images = self.numpy_to_pil(images)

        return EmuVisualGenerationPipelineOutput(
            image=images[0],
            nsfw_content_detected=None,
        )

    def _prepare_and_encode_inputs(
        self,
        inputs: List[str | Image.Image],
        do_classifier_free_guidance: bool = False,
        placeholder: str = DEFAULT_IMG_PLACEHOLDER,
    ):
        device = self.device(self.multimodal_encoder.model)
        dtype = self.dtype(self.multimodal_encoder.model)

        has_image, has_text = False, False
        text_prompt, image_prompt, image_grid_thw = "", [], []
        for x in inputs:
            if isinstance(x, str):
                has_text = True
                text_prompt += x
            else:
                has_image = True
                # Expand one image placeholder into Qwen2.5-VL vision tokens.
                text_prompt = text_prompt.replace(
                    placeholder,
                    "<|vision_start|>" + "<|image_pad|>" * 256 + "<|vision_end|>",
                    1,
                )
                resized_image = x.resize((EVA_IMAGE_SIZE, EVA_IMAGE_SIZE))
                image_inputs = image_processor(resized_image, return_tensors="pt")
                image_prompt.append(image_inputs.pixel_values)
                image_grid_thw.append(image_inputs.image_grid_thw)

        if len(image_prompt) == 0:
            image_prompt = None
            image_grid_thw = None
        else:
            image_prompt = torch.cat(image_prompt, dim=0)
            image_grid_thw = torch.cat(image_grid_thw, dim=0)

        if has_image and not has_text:
            prompt = self.multimodal_encoder.model.encode_image(image=image_prompt)
            if do_classifier_free_guidance:
                key = "[NULL_IMAGE]"
                if key not in self.negative_prompt:
                    negative_image = torch.zeros_like(image_prompt)
                    self.negative_prompt[key] = self.multimodal_encoder.model.encode_image(image=negative_image)
                prompt = torch.cat([prompt, self.negative_prompt[key]], dim=0)
        elif has_text and not has_image:
            prompt = self.multimodal_encoder.generate_image(
                text=[text_prompt], tokenizer=self.tokenizer
            )
            if do_classifier_free_guidance:
                key = ""
                if key not in self.negative_prompt:
                    self.negative_prompt[key] = self.multimodal_encoder.generate_image(
                        text=[" "],
                        tokenizer=self.tokenizer,
                    )
                prompt = torch.cat([prompt, self.negative_prompt[key]], dim=0)
        elif has_text and has_image:
            prompt = self.multimodal_encoder.generate_image(
                text=[text_prompt],
                pixel_values=image_prompt.to(device),
                image_grid_thw=image_grid_thw.to(device),
                tokenizer=self.tokenizer,
            )
            if do_classifier_free_guidance:
                key = ""
                if key not in self.negative_prompt:
                    self.negative_prompt[key] = self.multimodal_encoder.generate_image(
                        text=[" "],
                        tokenizer=self.tokenizer,
                    )
                prompt = torch.cat([prompt, self.negative_prompt[key]], dim=0)
        return prompt
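
    # The negative embeddings above are computed once per key ("" for text prompts,
    # "[NULL_IMAGE]" for image-only prompts) and cached in self.negative_prompt, so
    # classifier-free guidance adds no encoder cost after the first call.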

    def decode_latents(self, latents: torch.Tensor) -> np.ndarray:
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image
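
    # For reference: scaling_factor is the constant the VAE was trained with (e.g.
    # 0.18215 for the SD 1.x VAE, 0.13025 for the SDXL VAE); latents are divided by
    # it before decoding.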

    def numpy_to_pil(self, images: np.ndarray) -> List[Image.Image]:
        """
        Convert a numpy image or a batch of images to a PIL image.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        if images.shape[-1] == 1:
            # Special case for grayscale (single channel) images.
            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
        else:
            pil_images = [Image.fromarray(image) for image in images]
        return pil_images

    def run_safety_checker(self, images: np.ndarray):
        # Retained for completeness: __init__ registers safety_checker=None, so the
        # else branch always runs and no NSFW filtering is applied.
        if self.safety_checker is not None:
            device = self.device(self.safety_checker)
            dtype = self.dtype(self.safety_checker)
            safety_checker_input = self.feature_extractor(
                self.numpy_to_pil(images), return_tensors="pt"
            ).to(device)
            images, has_nsfw_concept = self.safety_checker(
                images=images, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        else:
            has_nsfw_concept = None
        return images, has_nsfw_concept
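
For context, a minimal sketch of how this custom pipeline might be driven once a compatible checkpoint is available. The checkpoint path, reference image, and prompt below are hypothetical placeholders; diffusers loads a standalone pipeline file through the custom_pipeline argument of DiffusionPipeline.from_pretrained.

import torch
from PIL import Image
from diffusers import DiffusionPipeline

# Hypothetical checkpoint directory: it must contain the tokenizer, multimodal
# encoder, scheduler, UNet, and VAE that __init__ registers.
pipe = DiffusionPipeline.from_pretrained(
    "path/to/llava-gen-checkpoint",
    custom_pipeline="pipeline_llava_gen.py",
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

# Interleaved text/image prompting: the <image> placeholder in the text is expanded
# to 256 vision tokens and paired with the PIL image that follows it in the list.
out = pipe(
    ["a photo in the style of <image>", Image.open("style_ref.jpg")],
    num_inference_steps=50,
    guidance_scale=3.0,
)
out.image.save("generation.png")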