Upload pipeline_llava_gen.py with huggingface_hub
pipeline_llava_gen.py ADDED (+287 -0)
@@ -0,0 +1,287 @@
# -*- coding: utf-8 -*-
# ===========================================================================================
#
# Copyright (c) Beijing Academy of Artificial Intelligence (BAAI). All rights reserved.
#
# Author        : Fan Zhang
# Email         : zhangfan@baai.ac.cn
# Institute     : Beijing Academy of Artificial Intelligence (BAAI)
# Create On     : 2023-12-19 10:45
# Last Modified : 2023-12-25 07:59
# File Name     : pipeline_emu2_gen.py
# Description   :
#
# ===========================================================================================

from dataclasses import dataclass
from typing import List, Optional

from PIL import Image
import numpy as np
import torch
from torchvision import transforms as TF
from tqdm import tqdm

from diffusers import DiffusionPipeline, UNet2DConditionModel, EulerDiscreteScheduler, AutoencoderKL
from diffusers.utils import BaseOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import CLIPImageProcessor
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor

EVA_IMAGE_SIZE = 448
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
DEFAULT_IMG_PLACEHOLDER = "<image>"

# Loaded once at import time: the Qwen2.5-VL image processor used to
# preprocess conditioning images for the multimodal encoder.
image_processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct").image_processor

@dataclass
class EmuVisualGenerationPipelineOutput(BaseOutput):
    image: Image.Image
    nsfw_content_detected: Optional[bool]


class EmuVisualGenerationPipeline(DiffusionPipeline):

    def __init__(
        self,
        tokenizer: AutoTokenizer,
        multimodal_encoder: AutoModelForCausalLM,
        scheduler: EulerDiscreteScheduler,
        unet: UNet2DConditionModel,
        vae: AutoencoderKL,
        feature_extractor: CLIPImageProcessor,
        safety_checker: StableDiffusionSafetyChecker,
        eva_size=EVA_IMAGE_SIZE,
        eva_mean=OPENAI_DATASET_MEAN,
        eva_std=OPENAI_DATASET_STD,
    ):
        super().__init__()
        self.register_modules(
            tokenizer=tokenizer,
            multimodal_encoder=multimodal_encoder,
            scheduler=scheduler,
            unet=unet,
            vae=vae,
            feature_extractor=feature_extractor,
            # Safety checking is disabled: None is registered even when a
            # safety_checker module is passed in.
            safety_checker=None,
        )

        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

        self.transform = TF.Compose([
            TF.Resize((eva_size, eva_size), interpolation=TF.InterpolationMode.BICUBIC),
            TF.ToTensor(),
            TF.Normalize(mean=eva_mean, std=eva_std),
        ])

        # Cache of negative (unconditional) embeddings, keyed by input kind.
        self.negative_prompt = {}

    # Helpers that read device/dtype off a specific sub-module, since the
    # pipeline's components may live on different devices.
    def device(self, module):
        return next(module.parameters()).device

    def dtype(self, module):
        return next(module.parameters()).dtype

    @torch.no_grad()
    def __call__(
        self,
        inputs: List[Image.Image | str] | str | Image.Image,
        height: int = 1024,
        width: int = 1024,
        num_inference_steps: int = 50,
        guidance_scale: float = 3.0,
        crop_info: List[int] = [0, 0],
        original_size: List[int] = [1024, 1024],
    ):
        if not isinstance(inputs, list):
            inputs = [inputs]

        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        device = self.device(self.unet)
        dtype = self.dtype(self.unet)

        do_classifier_free_guidance = guidance_scale > 1.0

        # 1. Encode input prompt
        prompt_embeds = self._prepare_and_encode_inputs(
            inputs,
            do_classifier_free_guidance,
        ).to(dtype).to(device)
        batch_size = prompt_embeds.shape[0] // 2 if do_classifier_free_guidance else prompt_embeds.shape[0]

        # SDXL-style micro-conditioning: original size, crop offsets, target size.
        unet_added_conditions = {}
        time_ids = torch.LongTensor(original_size + crop_info + [height, width]).to(device)
        if do_classifier_free_guidance:
            unet_added_conditions["time_ids"] = torch.cat([time_ids, time_ids], dim=0)
        else:
            unet_added_conditions["time_ids"] = time_ids
        unet_added_conditions["text_embeds"] = torch.mean(prompt_embeds, dim=1)

        # 2. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 3. Prepare latent variables
        shape = (
            batch_size,
            self.unet.config.in_channels,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        latents = torch.randn(shape, device=device, dtype=dtype)
        latents = latents * self.scheduler.init_noise_sigma

        # 4. Denoising loop
        for t in tqdm(timesteps):
            # Expand the latents if doing classifier-free guidance: 2B x 4 x H x W
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                added_cond_kwargs=unet_added_conditions,
            ).sample

            # Perform guidance; the conditional half comes first, matching the
            # cat order used in _prepare_and_encode_inputs.
            if do_classifier_free_guidance:
                noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

            # Compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        # 5. Post-processing
        images = self.decode_latents(latents)

        # 6. Safety checking is skipped (see run_safety_checker for the optional path)

        # 7. Convert to PIL
        images = self.numpy_to_pil(images)

        return EmuVisualGenerationPipelineOutput(
            image=images[0],
            nsfw_content_detected=None,
        )

    def _prepare_and_encode_inputs(
        self,
        inputs: List[str | Image.Image],
        do_classifier_free_guidance: bool = False,
        placeholder: str = DEFAULT_IMG_PLACEHOLDER,
    ):
        device = self.device(self.multimodal_encoder.model)
        dtype = self.dtype(self.multimodal_encoder.model)

        # Fold the interleaved inputs into a single text prompt, expanding the
        # image placeholder into Qwen2.5-VL vision tokens: a 448x448 image gives
        # 32x32 patches, i.e. 16x16 = 256 tokens after 2x2 spatial merging.
        has_image, has_text = False, False
        text_prompt, image_prompt, image_grid_thw = "", [], []
        for x in inputs:
            if isinstance(x, str):
                has_text = True
                text_prompt += x
            else:
                has_image = True
                text_prompt = text_prompt.replace(
                    placeholder,
                    "<|vision_start|>" + "<|image_pad|>" * 256 + "<|vision_end|>"
                )
                resized_image = x.resize((EVA_IMAGE_SIZE, EVA_IMAGE_SIZE))
                image_inputs = image_processor(resized_image, return_tensors="pt")
                image_prompt.append(image_inputs.pixel_values)
                image_grid_thw.append(image_inputs.image_grid_thw)

        if len(image_prompt) == 0:
            image_prompt = None
            image_grid_thw = None
        else:
            image_prompt = torch.cat(image_prompt, dim=0)
            image_grid_thw = torch.cat(image_grid_thw, dim=0)

        if has_image and not has_text:
            prompt = self.multimodal_encoder.model.encode_image(image=image_prompt)
            if do_classifier_free_guidance:
                # Negative conditioning for image-only input: an all-zero image.
                key = "[NULL_IMAGE]"
                if key not in self.negative_prompt:
                    negative_image = torch.zeros_like(image_prompt)
                    self.negative_prompt[key] = self.multimodal_encoder.model.encode_image(image=negative_image)
                prompt = torch.cat([prompt, self.negative_prompt[key]], dim=0)
        elif has_text and not has_image:
            prompt = self.multimodal_encoder.generate_image(
                text=[text_prompt], tokenizer=self.tokenizer
            )
            if do_classifier_free_guidance:
                # Negative conditioning for text input: a single-space prompt.
                key = ""
                if key not in self.negative_prompt:
                    self.negative_prompt[key] = self.multimodal_encoder.generate_image(
                        text=[" "],
                        tokenizer=self.tokenizer,
                    )
                prompt = torch.cat([prompt, self.negative_prompt[key]], dim=0)
        elif has_text and has_image:
            prompt = self.multimodal_encoder.generate_image(
                text=[text_prompt],
                pixel_values=image_prompt.to(device),
                image_grid_thw=image_grid_thw.to(device),
                tokenizer=self.tokenizer,
            )
            if do_classifier_free_guidance:
                key = ""
                if key not in self.negative_prompt:
                    self.negative_prompt[key] = self.multimodal_encoder.generate_image(
                        text=[" "],
                        tokenizer=self.tokenizer,
                    )
                prompt = torch.cat([prompt, self.negative_prompt[key]], dim=0)
        return prompt

    def decode_latents(self, latents: torch.Tensor) -> np.ndarray:
        # Undo the VAE scaling, decode to pixel space, and map [-1, 1] -> [0, 1].
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    def numpy_to_pil(self, images: np.ndarray) -> List[Image.Image]:
        """
        Convert a numpy image or a batch of images to a PIL image.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        if images.shape[-1] == 1:
            # Special case for grayscale (single channel) images.
            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
        else:
            pil_images = [Image.fromarray(image) for image in images]
        return pil_images

    def run_safety_checker(self, images: np.ndarray):
        if self.safety_checker is not None:
            device = self.device(self.safety_checker)
            dtype = self.dtype(self.safety_checker)
            safety_checker_input = self.feature_extractor(
                self.numpy_to_pil(images), return_tensors="pt"
            ).to(device)
            images, has_nsfw_concept = self.safety_checker(
                images=images, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        else:
            has_nsfw_concept = None
        return images, has_nsfw_concept
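
Example usage (a minimal sketch, not part of this commit; the checkpoint path is
hypothetical and assumes the repository ships a model_index.json that maps every
component registered in __init__ so from_pretrained can resolve them):

    import torch
    from PIL import Image
    from pipeline_llava_gen import EmuVisualGenerationPipeline

    pipe = EmuVisualGenerationPipeline.from_pretrained(
        "path/to/checkpoint",          # hypothetical local checkpoint directory
        torch_dtype=torch.bfloat16,
    )
    pipe.to("cuda")

    # Interleaved conditioning: the "<image>" placeholder in the text is expanded
    # into vision tokens when the following PIL image is encountered.
    output = pipe(
        inputs=[
            "a dog like this one <image> but wearing a red hat",
            Image.open("dog.png").convert("RGB"),
        ],
        num_inference_steps=50,
        guidance_scale=3.0,
    )
    output.image.save("result.png")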