| | import re |
| | import copy |
| | from typing import Literal |
| |
|
| | from PIL import Image |
| | from tqdm.auto import tqdm |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torchvision.transforms as transforms |
| |
|
| | from transformers import AutoTokenizer |
| | from transformers.cache_utils import Cache, StaticCache |
| |
|
| | from models.nextstep_model import NextStep |
| | from vae.nextstep_ae import AutoencoderKL |
| | from utils.image_utils import to_pil |
| | from utils.model_utils import layer_norm |
| | from utils.compile_utils import compile_manager |
| | from utils.misc import set_seed |
| |
|
# Marker text written in front of the "<h>*<w>" latent-grid string (see hw2str)
# whenever an image span or image prefix is rendered into a prompt.
DEFAULT_IMAGE_AREA_TOKEN = "<|image_area|>"
| |
|
| |
|
def hw2str(h: int, w: int) -> str:
    """Render a (height, width) pair as the string ``"<h>*<w>"``."""
    return "{}*{}".format(h, w)
| |
|
| |
|
class NextStepPipeline:
    """Inference pipeline that pairs a NextStep model with a KL autoencoder.

    The pipeline turns captions (optionally with one reference image per
    caption) into a token prompt, runs the model once over the prompt, samples
    ``(hw[0] // down_factor) * (hw[1] // down_factor)`` continuous image tokens
    one at a time, and decodes them to PIL images with the VAE.
    """

    def __init__(
        self,
        model_name_or_path: str | None = None,
        vae_name_or_path: str | None = None,
        tokenizer: AutoTokenizer | None = None,
        model: nn.Module | None = None,
        vae: AutoencoderKL | None = None,
    ):
        """Build the pipeline from in-memory objects or from local checkpoints.

        Args:
            model_name_or_path: Local checkpoint for tokenizer and model; used
                only when ``model`` is not given.
            vae_name_or_path: Checkpoint for the VAE; when omitted, falls back
                to ``model.config.vae_name_or_path`` if that attribute exists.
            tokenizer: Tokenizer to use together with ``model``. It is
                deep-copied, so the caller's instance is not modified.
            model: Already instantiated model. Takes precedence over
                ``model_name_or_path``.
            vae: Already instantiated VAE. Takes precedence over
                ``vae_name_or_path``.

        Raises:
            ValueError: If neither ``model`` nor ``model_name_or_path`` is
                given, if ``model`` is given without ``tokenizer``, or if no
                VAE can be resolved.
        """
        if model is not None:
            if tokenizer is None:
                # Previously this surfaced as an AttributeError on None.
                raise ValueError("tokenizer is required when model is provided")
            self.tokenizer = copy.deepcopy(tokenizer)
            self.tokenizer.padding_side = "left"
            self.model = model
        elif model_name_or_path is not None:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name_or_path,
                local_files_only=True,
                model_max_length=4096,
                padding_side="left",
                use_fast=True,
            )
            self.model: NextStep = NextStep.from_pretrained(model_name_or_path, local_files_only=True)
        else:
            raise ValueError("model or model_name_or_path is required")

        self.tokenizer.add_eos_token = False

        if vae_name_or_path is None:
            vae_name_or_path = getattr(self.model.config, "vae_name_or_path", None)

        if vae is not None:
            self.vae = vae
        elif vae_name_or_path is not None:
            self.vae = AutoencoderKL.from_pretrained(vae_name_or_path)
        else:
            raise ValueError("vae or vae_name_or_path is required")

        self.model.eval()
        self.vae.eval()

        # Pixels per image token along one side: VAE down-sampling times the
        # model's latent patch size.
        vae_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.down_factor = vae_factor * self.model.config.latent_patch_size

        self.shift_factor = getattr(self.vae.config, "shift_factor", 0.0)
        self.scaling_factor = getattr(self.vae.config, "scaling_factor", 1.0)

        self.boi = self.model.config.boi
        self.eoi = self.model.config.eoi
        self.image_placeholder_id = self.model.config.image_placeholder_id

        # PIL image -> float tensor scaled from [0, 1] to [-1, 1] (3 channels).
        self.pil2tensor = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )

        self.__device = self.model.device
        self.__dtype = self.model.dtype

    @property
    def device(self):
        """Device last recorded for the pipeline, exactly as it was supplied."""
        return self.__device

    @property
    def device_type(self):
        """Device type without index, e.g. ``"cuda"`` for ``"cuda:0"``."""
        # Normalising through torch.device strips any ":<index>" suffix from a
        # string device, which the raw string would otherwise carry along.
        return torch.device(self.__device).type

    @property
    def dtype(self):
        """Dtype last recorded for the pipeline."""
        return self.__dtype

    def to(self, device: str | torch.device | None = None, dtype: torch.dtype | None = None):
        """Move model and VAE to ``device`` / ``dtype`` and return ``self``.

        Arguments left as ``None`` keep the previously recorded value.
        """
        if device is not None:
            self.__device = device
        if dtype is not None:
            self.__dtype = dtype
        self.model.to(self.__device, dtype=self.__dtype)
        self.vae.to(self.__device, dtype=self.__dtype)
        return self

    def _image_str(self, hw: tuple[int, int] = (256, 256)):
        """Return the prompt text standing in for one image of pixel size ``hw``.

        The text is the image-area marker, the latent grid size, and the
        decoded ``boi`` + placeholder * (latent_h * latent_w) + ``eoi`` ids.
        """
        latent_hw = (hw[0] // self.down_factor, hw[1] // self.down_factor)
        image_ids = [self.boi] + [self.image_placeholder_id] * (latent_hw[0] * latent_hw[1]) + [self.eoi]
        image_str = DEFAULT_IMAGE_AREA_TOKEN + hw2str(*latent_hw) + self.tokenizer.decode(image_ids)
        return image_str

    def _check_input(
        self, captions: str | list[str], images: Image.Image | list[Image.Image] | None
    ) -> tuple[list[str], list[Image.Image] | None]:
        """Normalise inputs to lists and expand ``<image>`` tags in captions.

        When ``images`` is given, every caption must contain exactly one
        ``<image>`` tag and the number of tags must equal the number of images;
        each tag is replaced by the placeholder text for the matching image.
        When ``images`` is ``None`` the captions are returned unchanged.

        Raises:
            ValueError: If a caption does not contain exactly one ``<image>``
                tag, or if tags and images do not match in number.
        """
        if not isinstance(captions, list):
            captions = [captions]
        if images is None:
            return captions, None
        if not isinstance(images, list):
            images = [images]

        image_token_count = 0
        for caption in captions:
            num_image_token = caption.count("<image>")
            # Explicit raise instead of `assert`, which disappears under -O.
            if num_image_token != 1:
                raise ValueError(
                    f"Caption `{caption}` has {num_image_token} image tokens, but only 1 is allowed."
                )
            image_token_count += num_image_token
        if image_token_count != len(images):
            raise ValueError(
                f"Number of images ({len(images)}) does not match number of image tokens ({image_token_count}).\n"
                f"Captions: {captions}"
            )

        # PIL reports (width, height); the pipeline works with (height, width).
        hws = [(image.size[1], image.size[0]) for image in images]

        # Exactly one tag per caption, so captions and images pair up 1:1.
        captions = [
            caption.replace("<image>", self._image_str(image_hw), 1)
            for caption, image_hw in zip(captions, hws)
        ]
        return captions, images

    def _build_captions(
        self,
        captions: str | list[str],
        images: list[Image.Image] | None = None,
        num_images_per_caption: int = 1,
        positive_prompt: str | None = None,
        negative_prompt: str | None = None,
        cfg: float = 1.0,
        cfg_img: float = 1.0,
    ):
        """Assemble the full prompt batch, including guidance branches.

        Each caption (and image) is repeated ``num_images_per_caption`` times
        and suffixed with ``" " + positive_prompt``. Depending on the guidance
        scales, extra prompts are appended after the conditional ones:

        * ``cfg != 1`` and ``cfg_img != 1``: image-only prompts (image
          placeholder + negative prompt), then negative-prompt-only prompts;
          ``images`` is duplicated for the image-only branch.
        * ``cfg != 1`` and ``cfg_img == 1``: negative-prompt-only prompts.
        * otherwise: nothing is appended.

        Returns:
            Tuple ``(captions, images)``.

        Raises:
            ValueError: If both guidance scales differ from 1 but no images
                were supplied.
        """
        if not isinstance(captions, list):
            captions = [captions]
        captions = [caption for caption in captions for _ in range(num_images_per_caption)]
        if images is not None:
            images = [image for image in images for _ in range(num_images_per_caption)]

        if positive_prompt is None:
            positive_prompt = ""
        # The separating space is kept even for an empty positive prompt so the
        # produced prompt text stays identical to earlier behaviour.
        captions = [f"{caption} {positive_prompt}" for caption in captions]

        if negative_prompt is None:
            negative_prompt = ""
        num_samples = len(captions)
        if cfg != 1.0 and cfg_img != 1.0:
            if not images:
                raise ValueError("images are required when both cfg and cfg_img differ from 1.0")
            # The first image's size is used for every image-only prompt.
            w, h = images[0].size
            captions = captions + [self._image_str((h, w)) + negative_prompt] * num_samples
            images = images + images
            captions = captions + [negative_prompt] * num_samples
        elif cfg != 1.0 and cfg_img == 1.0:
            captions = captions + [negative_prompt] * num_samples
        # NOTE(review): cfg == 1.0 with cfg_img != 1.0 adds no guidance prompts
        # here; confirm whether that combination is meant to be supported.

        return captions, images

    def _add_prefix_ids(self, hw: tuple[int, int], input_ids: torch.Tensor, attention_mask: torch.Tensor):
        """Append the image-area prefix and the ``boi`` id to every prompt.

        Args:
            hw: Target image size in pixels as ``(height, width)``.
            input_ids: Prompt ids; the prefix is concatenated along dim 1.
            attention_mask: Mask matching ``input_ids``; extended with ones.

        Returns:
            Tuple ``(input_ids, attention_mask)`` with the prefix appended.
        """
        prefix_str = DEFAULT_IMAGE_AREA_TOKEN + hw2str(hw[0] // self.down_factor, hw[1] // self.down_factor)
        prefix_output = self.tokenizer(prefix_str, truncation=False, add_special_tokens=True, return_tensors="pt")
        prefix_input_ids = prefix_output.input_ids.to(input_ids.device, dtype=input_ids.dtype)
        prefix_attention_mask = prefix_output.attention_mask.to(attention_mask.device, dtype=attention_mask.dtype)

        # The prefix goes in the middle of a sequence, so a BOS added by the
        # tokenizer must go. Only strip when BOS was really emitted: a
        # tokenizer may define a BOS token without prepending it.
        bos_id = self.tokenizer.bos_token_id
        if (
            bos_id is not None
            and prefix_input_ids.shape[1] > 0
            and prefix_input_ids[0, 0].item() == bos_id
        ):
            prefix_input_ids = prefix_input_ids[:, 1:]
            prefix_attention_mask = prefix_attention_mask[:, 1:]

        prefix_input_ids = torch.cat(
            [
                prefix_input_ids,
                prefix_input_ids.new_tensor([self.model.config.boi]).unsqueeze(0),
            ],
            dim=1,
        )
        prefix_attention_mask = torch.cat(
            [
                prefix_attention_mask,
                prefix_attention_mask.new_ones((prefix_attention_mask.shape[0], 1)),
            ],
            dim=1,
        )

        bsz = input_ids.shape[0]
        input_ids = torch.cat([input_ids, prefix_input_ids.expand(bsz, -1)], dim=1)
        attention_mask = torch.cat([attention_mask, prefix_attention_mask.expand(bsz, -1)], dim=1)

        return input_ids, attention_mask

    @torch.no_grad()
    def decoding(
        self,
        c: torch.Tensor,
        attention_mask: torch.Tensor,
        past_key_values: Cache,
        max_new_len: int,
        num_images_per_caption: int,
        use_norm: bool = False,
        cfg: float = 1.0,
        cfg_img: float = 1.0,
        cfg_schedule: Literal["linear", "constant"] = "constant",
        timesteps_shift: float = 1.0,
        num_sampling_steps: int = 20,
        progress: bool = True,
    ):
        """Sample ``max_new_len`` image tokens autoregressively.

        Args:
            c: Hidden state of the last prompt position, as produced by
                ``forward_model(...).last_hidden_state[:, -1:]``.
            attention_mask: Mask covering everything already in the cache; one
                column of ones is appended per generated token.
            past_key_values: Cache filled by the prompt forward pass.
            max_new_len: Number of image tokens to sample.
            num_images_per_caption: Forwarded to the image head as
                ``noise_repeat``.
            use_norm: If true, the sampled token is layer-normalised before it
                is fed back to the model. The returned tokens are never
                normalised.
            cfg: Guidance scale for the text condition.
            cfg_img: Guidance scale for the image condition.
            cfg_schedule: ``"constant"`` uses the scales as given; ``"linear"``
                ramps each scale ``s`` as
                ``max(s / 2, 1 + (s - 1) * step / max_new_len)``.
            timesteps_shift: Forwarded to the image head sampler.
            num_sampling_steps: Forwarded to the image head sampler.
            progress: Show a tqdm progress bar.

        Returns:
            The un-normalised sampled tokens stacked along dim 1, or ``None``
            when ``max_new_len`` is 0.

        Raises:
            NotImplementedError: For an unknown ``cfg_schedule``.
        """
        if cfg_schedule not in ("linear", "constant"):
            raise NotImplementedError(f"Unsupported cfg_schedule: {cfg_schedule!r}")

        # The prompt batch holds one copy of each sample per guidance branch
        # (see _build_captions), so the fed-back embedding is replicated to match.
        if cfg != 1.0:
            num_branches = 3 if cfg_img != 1.0 else 2
        else:
            num_branches = 1

        steps = range(max_new_len)
        if progress:
            steps = tqdm(steps, unit="tokens")

        # Collected in a list and stacked once at the end; concatenating inside
        # the loop would re-copy all earlier tokens on every step.
        unnormed_tokens = []
        for step in steps:
            if cfg_schedule == "linear":
                cfg_iter = max(cfg / 2, 1 + (cfg - 1) * step / max_new_len)
                cfg_img_iter = max(cfg_img / 2, 1 + (cfg_img - 1) * step / max_new_len)
            else:
                cfg_iter = cfg
                cfg_img_iter = cfg_img

            c = self.model.image_out_projector(c)
            token_sampled = self.model.image_head.sample(
                c=c.squeeze(1),
                cfg=cfg_iter,
                cfg_img=cfg_img_iter,
                timesteps_shift=timesteps_shift,
                num_sampling_steps=num_sampling_steps,
                noise_repeat=num_images_per_caption,
            )

            # Keep an independent copy of the raw sample for the return value.
            unnormed_tokens.append(token_sampled.clone())
            if use_norm:
                token_sampled = layer_norm(token_sampled, normalized_shape=token_sampled.size()[1:])

            if step == max_new_len - 1:
                # Nothing is sampled after the last token, so the hidden state
                # a further forward pass would produce is never used.
                continue

            cur_inputs_embeds = self.model.image_in_projector(token_sampled.unsqueeze(1))
            if num_branches > 1:
                cur_inputs_embeds = torch.cat([cur_inputs_embeds] * num_branches, dim=0)

            attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
            outputs = self.model.forward_model(
                inputs_embeds=cur_inputs_embeds,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                use_cache=True,
            )
            past_key_values = outputs.past_key_values
            c = outputs.last_hidden_state[:, -1:]

        if not unnormed_tokens:
            return None
        return torch.stack(unnormed_tokens, dim=1)

    @torch.no_grad()
    def generate_image(
        self,
        captions: str | list[str],
        images: list[Image.Image] | None = None,
        num_images_per_caption: int = 1,
        positive_prompt: str | None = None,
        negative_prompt: str | None = None,
        hw: tuple[int, int] = (256, 256),
        use_norm: bool = False,
        cfg: float = 1.0,
        cfg_img: float = 1.0,
        cfg_schedule: Literal["linear", "constant"] = "constant",
        num_sampling_steps: int = 20,
        timesteps_shift: float = 1.0,
        seed: int = 42,
        progress: bool = True,
    ) -> list[Image.Image]:
        """Generate images for the given captions (and optional input images).

        Args:
            captions: One caption or a list of captions. With ``images``, each
                caption must contain exactly one ``<image>`` tag.
            images: Optional reference images, one per caption.
            num_images_per_caption: Number of samples per caption.
            positive_prompt: Text appended to every caption.
            negative_prompt: Text used for the guidance branches.
            hw: Output size in pixels as ``(height, width)``; each side is
                floored to a multiple of ``self.down_factor``.
            use_norm: See :meth:`decoding`.
            cfg: Guidance scale for the text condition.
            cfg_img: Guidance scale for the image condition.
            cfg_schedule: Guidance schedule, ``"constant"`` or ``"linear"``.
            num_sampling_steps: Sampler steps per image token.
            timesteps_shift: Forwarded to the image head sampler.
            seed: Seed passed to ``set_seed``; ``None`` disables seeding.
            progress: Show a tqdm progress bar while decoding.

        Returns:
            The generated images as a list of PIL images.

        Raises:
            ValueError: If ``hw`` is smaller than ``self.down_factor`` on
                either side, or if captions and images are inconsistent.
        """
        latent_h = hw[0] // self.down_factor
        latent_w = hw[1] // self.down_factor
        if latent_h <= 0 or latent_w <= 0:
            raise ValueError(
                f"hw={hw} is too small: each side must be at least {self.down_factor} pixels."
            )

        # 1. validate inputs and expand <image> tags
        captions, images = self._check_input(captions, images)

        # 2. add positive prompt and guidance branches
        captions, images = self._build_captions(
            captions, images, num_images_per_caption, positive_prompt, negative_prompt, cfg, cfg_img
        )

        # 3. encode reference images to latents
        latents = None
        if images is not None:
            if seed is not None:
                # Seed before drawing from the posterior so the image
                # conditioning is reproducible as well.
                # NOTE(review): assumes set_seed seeds the RNG used by
                # posterior.sample() - confirm.
                set_seed(seed)
            pixel_values = [self.pil2tensor(image) for image in images]
            pixel_values = torch.stack(pixel_values).to(self.device)
            with compile_manager.compile_disabled():
                posterior = self.vae.encode(pixel_values.to(self.vae.dtype)).latent_dist
            latents = (posterior.sample() - self.shift_factor) * self.scaling_factor

        # Seed (again) here so the noise consumed while decoding is the same
        # whether or not images were encoded above.
        if seed is not None:
            set_seed(seed)

        # 4. tokenize and append the image prefix
        output = self.tokenizer(
            captions,
            padding="longest",
            truncation=False,
            add_special_tokens=True,
            return_tensors="pt",
            padding_side="left",
        )
        input_ids = output.input_ids.to(self.device)
        attention_mask = output.attention_mask.to(self.device)
        input_ids, attention_mask = self._add_prefix_ids(hw, input_ids, attention_mask)

        # 5. run the prompt through the model with a cache sized for the
        #    prompt plus every image token
        max_new_len = latent_h * latent_w
        max_cache_len = input_ids.shape[1] + max_new_len
        past_key_values = StaticCache(
            config=self.model.config,
            max_batch_size=input_ids.shape[0],
            max_cache_len=max_cache_len,
            device=self.device,
            dtype=self.dtype,
        )
        inputs_embeds = self.model.prepare_inputs_embeds(input_ids, latents)
        with compile_manager.compile_disabled():
            outputs = self.model.forward_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                use_cache=True,
            )
        past_key_values = outputs.past_key_values
        c = outputs.last_hidden_state[:, -1:]

        # 6. sample image tokens
        tokens = self.decoding(
            c=c,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            max_new_len=max_new_len,
            num_images_per_caption=num_images_per_caption,
            use_norm=use_norm,
            cfg=cfg,
            cfg_img=cfg_img,
            cfg_schedule=cfg_schedule,
            timesteps_shift=timesteps_shift,
            num_sampling_steps=num_sampling_steps,
            progress=progress,
        )

        # 7. tokens -> latents, undoing the scaling applied at encode time
        latents = self.model.unpatchify(tokens, h=latent_h, w=latent_w)
        latents = (latents / self.scaling_factor) + self.shift_factor

        # 8. latents -> PIL images
        with compile_manager.compile_disabled():
            sampled_images = self.vae.decode(latents.to(self.vae.dtype)).sample
        sampled_images = sampled_images.detach().cpu().to(torch.float32)
        pil_images = [to_pil(img) for img in sampled_images]

        return pil_images