| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from dataclasses import dataclass, field, asdict |
| | from typing import Tuple, Optional, Callable, Union, Any |
| | import random |
| | import math |
| |
|
| | import torch |
| | from PIL import Image |
| | from torchvision import transforms |
| | from transformers.image_processing_utils import BaseImageProcessor |
| | from transformers.image_utils import load_image |
| | from transformers.models.siglip2.image_processing_siglip2_fast import Siglip2ImageProcessorFast |
| | from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList |
| |
|
| | from .tokenization_hunyuan_image_3 import ImageInfo, ImageTensor, CondImage, Resolution, ResolutionGroup |
| |
|
| | InputImage = Union[Image.Image, str] |
| |
|
| |
|
class SliceVocabLogitsProcessor(LogitsProcessor):
    """
    [`LogitsProcessor`] that performs vocab slicing, i.e. restricting probabilities within some range. This processor
    is often used in multimodal discrete LLMs, which ensures that we only sample within one modality.

    Args:
        vocab_start (`int`, *optional*): start of slice; default None meaning from 0.
        vocab_end (`int`, *optional*): end of slice; default None meaning to the end of the vocab.
            When start and end are both None, the primary slice keeps the whole vocab.

    Keyword Args:
        other_slices (`list[tuple[int, int]]`, *optional*): extra `(start, end)` vocab ranges whose
            scores are concatenated after the primary slice.
    """

    def __init__(self, vocab_start: Optional[int] = None, vocab_end: Optional[int] = None, **kwargs):
        if vocab_start is not None and vocab_end is not None:
            assert vocab_start < vocab_end, f"Ensure vocab_start {vocab_start} < vocab_end {vocab_end}"
        self.vocab_start = vocab_start
        self.vocab_end = vocab_end
        # Additional (start, end) ranges appended after the primary slice.
        self.other_slices = kwargs.get("other_slices", [])

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Keep only the primary vocab range, then append any extra ranges along the vocab dim.
        scores_processed = scores[:, self.vocab_start: self.vocab_end]
        for other_slice in self.other_slices:
            scores_processed = torch.cat([scores_processed, scores[:, other_slice[0]: other_slice[1]]], dim=-1)
        return scores_processed

    def __repr__(self):
        # Fix: previously reported the wrong class name ("SliceVocabLogitsWarper").
        return (f"{self.__class__.__name__}(vocab_start={self.vocab_start}, "
                f"vocab_end={self.vocab_end}, other_slices={self.other_slices})")
|
| |
|
def resize_and_crop(image: Image.Image, target_size: Tuple[int, int], resample=Image.Resampling.LANCZOS, crop_type='center', crop_coords=None) -> Image.Image:
    """Resize *image* to exactly *target_size*, cropping overflow as requested.

    Args:
        image: Source PIL image.
        target_size: Desired ``(width, height)`` in pixels.
        resample: PIL resampling filter used for resizing.
        crop_type: One of ``"resize"`` (plain resize, aspect ratio not
            preserved, no crop), ``"center"``, ``"random"``, or ``"fixed"``.
        crop_coords: ``(left, top)`` crop origin, required when
            ``crop_type == "fixed"``.

    Returns:
        A PIL image of size ``target_size``.
    """
    target_w, target_h = target_size
    src_w, src_h = image.size

    if crop_type == "resize":
        # Direct stretch to target; no aspect-ratio preservation, no crop.
        return image.resize((target_w, target_h), resample=resample)

    # Cover-fit: scale so the image fully covers the target box while
    # preserving aspect ratio (h/w < th/tw means the source is wider).
    if src_h / src_w < target_h / target_w:
        new_h = target_h
        new_w = int(round(target_h / src_h * src_w))
    else:
        new_w = target_w
        new_h = int(round(target_w / src_w * src_h))

    if crop_type == 'center':
        top = int(round((new_h - target_h) / 2.0))
        left = int(round((new_w - target_w) / 2.0))
    elif crop_type == 'random':
        top = random.randint(0, new_h - target_h)
        left = random.randint(0, new_w - target_w)
    elif crop_type == 'fixed':
        assert crop_coords is not None, 'crop_coords should be provided when crop_type is fixed.'
        left, top = crop_coords
    else:
        raise ValueError(f'crop_type must be center, random or fixed, but got {crop_type}')

    resized = image.resize((new_w, new_h), resample=resample)
    return resized.crop((left, top, left + target_w, top + target_h))
| |
|
| |
|
@dataclass
class ResolutionGroupConfig:
    """Configuration used to construct a `ResolutionGroup`.

    Attributes:
        base_size: Base (square) image size in pixels; ``None`` means unset.
        step: Step between resolution buckets; ``None`` lets the group decide.
        align: Pixel alignment that bucket resolutions are rounded to.
    """
    # Fix: annotation was `int` while the default is `None`; `Optional[int]` matches reality.
    base_size: Optional[int] = None
    step: Optional[int] = None
    align: int = 16

    def to_dict(self) -> dict:
        """Return the config as a plain dict (suitable for ``**`` expansion)."""
        return asdict(self)
|
| |
|
@dataclass
class VAEInfo:
    """Metadata describing a VAE image encoder and its token geometry.

    Attributes:
        encoder_type: Identifier of the VAE encoder.
        down_h_factor: VAE spatial downsample factor along height.
        down_w_factor: VAE spatial downsample factor along width.
        patch_size: Additional patchify factor applied after the VAE.
        h_factor: Total pixels-per-token along height (derived in __post_init__).
        w_factor: Total pixels-per-token along width (derived in __post_init__).
        image_type: Logical image type tag; defaults to "vae" when None.
    """
    encoder_type: str
    down_h_factor: int = -1
    down_w_factor: int = -1
    patch_size: int = 1
    h_factor: int = -1
    w_factor: int = -1
    # Fix: annotation was `str` while the default is `None`.
    image_type: Optional[str] = None

    def __post_init__(self):
        # Total spatial reduction = VAE downsample factor * patchify factor.
        self.h_factor = self.down_h_factor * self.patch_size
        self.w_factor = self.down_w_factor * self.patch_size
        if self.image_type is None:
            self.image_type = "vae"
|
| |
|
@dataclass
class ViTInfo:
    """Metadata describing a ViT image encoder and its token geometry."""
    # Encoder identifier, e.g. "siglip2-so400m-patch16-naflex".
    encoder_type: str
    # Pixels per token along height/width (ViT patch size); -1 = unset.
    h_factor: int = -1
    w_factor: int = -1
    # Maximum number of vision tokens produced per image.
    max_token_length: int = 0
    # Image preprocessor callable producing model-ready inputs.
    processor: Callable = field(default_factory=BaseImageProcessor)
    # Logical image type tag; derived from `encoder_type` when None.
    image_type: Optional[str] = None

    def __post_init__(self):
        if self.image_type is None:
            # e.g. "siglip2-so400m-patch16-naflex" -> "siglip2"
            self.image_type = self.encoder_type.split("-")[0]
|
| |
|
class HunyuanImage3ImageProcessor(object):
    """Image pre/post-processor for HunyuanImage-3.

    Bridges PIL images and model-ready tensors for both the VAE branch
    (generation / conditioning latents) and the ViT branch (SigLIP2 vision
    conditioning). Also handles resolution bucketing, attention-slice
    bookkeeping, and ratio-token logits restriction.
    """

    def __init__(self, config):
        self.config = config

        # Resolution buckets used to snap arbitrary sizes to supported ones.
        self.reso_group_config = ResolutionGroupConfig(base_size=config.image_base_size)
        self.vae_reso_group = ResolutionGroup(
            **self.reso_group_config.to_dict(),
            extra_resolutions=[
                Resolution("1024x768"),
                Resolution("1280x720"),
                Resolution("768x1024"),
                Resolution("720x1280"),
            ]
        )
        # Lazily built by `build_img_ratio_slice_logits_processor`.
        self.img_ratio_slice_logits_processor = None
        # PIL image -> float tensor normalized to [-1, 1].
        self.pil_image_to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])
        self.vae_info = VAEInfo(
            encoder_type=config.vae_type,
            # Fix: the width factor previously read index [0] (the height entry)
            # twice; use [1] for the width downsample factor.
            down_h_factor=config.vae_downsample_factor[0], down_w_factor=config.vae_downsample_factor[1],
            patch_size=config.patch_size,
        )

        if config.vit_type == "siglip2-so400m-patch16-naflex":
            self.vit_processor = Siglip2ImageProcessorFast.from_dict(config.vit_processor)
        else:
            raise ValueError(f"Unsupported vit_type: {config.vit_type}")
        self.vit_info = ViTInfo(
            encoder_type=config.vit_type,
            h_factor=self.vit_processor.patch_size,
            w_factor=self.vit_processor.patch_size,
            max_token_length=self.vit_processor.max_num_patches,
            processor=self.vit_processor,
        )
        self.cond_token_attn_type = config.cond_token_attn_type
        self.cond_image_type = config.cond_image_type

    def build_gen_image_info(self, image_size, add_guidance_token=False, add_timestep_r_token=False) -> ImageInfo:
        """Build an `ImageInfo` describing an image to be generated.

        Args:
            image_size: Either an ``"<img_ratio_i>"`` token string, an ``"HxW"``
                string, a ``"W:H"`` string, or an ``(H, W)`` pair of ints.
            add_guidance_token: Whether a guidance token accompanies the image.
            add_timestep_r_token: Whether a timestep-r token accompanies the image.

        Returns:
            `ImageInfo` with pixel/token dimensions snapped to the resolution group.

        Raises:
            ValueError: If `image_size` is not in a recognized format.
        """
        # Normalize `image_size` to an (H, W) pair.
        if isinstance(image_size, str):
            if image_size.startswith("<img_ratio_"):
                ratio_index = int(image_size.split("_")[-1].rstrip(">"))
                reso = self.vae_reso_group[ratio_index]
                image_size = reso.height, reso.width
            elif 'x' in image_size:
                image_size = [int(s) for s in image_size.split('x')]
            elif ':' in image_size:
                image_size = [int(s) for s in image_size.split(':')]
                assert len(image_size) == 2, f"`image_size` should be in the format of 'W:H', got {image_size}."
                # "W:H" -> (H, W)
                image_size = [image_size[1], image_size[0]]
            else:
                raise ValueError(
                    f"`image_size` should be in the format of 'HxW', 'W:H' or <img_ratio_i>, got {image_size}.")
            assert len(image_size) == 2, f"`image_size` should be in the format of 'HxW', got {image_size}."
        elif isinstance(image_size, (list, tuple)):
            assert len(image_size) == 2 and all(isinstance(s, int) for s in image_size), \
                f"`image_size` should be a tuple of two integers or a string in the format of 'HxW', got {image_size}."
        else:
            raise ValueError(f"`image_size` should be a tuple of two integers or a string in the format of 'WxH', "
                             f"got {image_size}.")
        # Snap to the nearest supported resolution (group methods take (W, H)).
        image_width, image_height = self.vae_reso_group.get_target_size(image_size[1], image_size[0])
        token_height = image_height // self.vae_info.h_factor
        token_width = image_width // self.vae_info.w_factor
        base_size, ratio_idx = self.vae_reso_group.get_base_size_and_ratio_index(image_size[1], image_size[0])
        image_info = ImageInfo(
            image_type="gen_image", image_width=image_width, image_height=image_height,
            token_width=token_width, token_height=token_height, base_size=base_size, ratio_index=ratio_idx,
            add_guidance_token=add_guidance_token, add_timestep_r_token=add_timestep_r_token,
        )
        return image_info

    def as_image_tensor(self, image, image_type, **kwargs) -> ImageTensor:
        """Wrap an image (PIL or tensor) as an `ImageTensor` with attached metadata.

        Args:
            image: PIL image (converted via `pil_image_to_tensor`) or a tensor.
            image_type: "vae", "siglip2", or "anyres".
            **kwargs: `origin_size` (W, H) and type-specific extras
                (e.g. `spatial_shapes`, `pixel_attention_mask` for siglip2).

        Raises:
            ValueError: If `image_type` is unknown.
        """
        if isinstance(image, Image.Image):
            tensor = self.pil_image_to_tensor(image)
        else:
            tensor = image
        # Original (pre-resize) size, as PIL's (width, height).
        origin_size = kwargs["origin_size"]
        ori_image_width = origin_size[0]
        ori_image_height = origin_size[1]

        if image_type == "vae":
            assert tensor.ndim == 3 or tensor.ndim == 4
            h, w = tensor.shape[-2], tensor.shape[-1]
            assert (h % self.vae_info.h_factor == 0 and w % self.vae_info.w_factor == 0), \
                (f"Image size should be divisible by ({self.vae_info.h_factor}, {self.vae_info.w_factor}), "
                 f"but got ({h} x {w}).")
            tk_height = h // self.vae_info.h_factor
            tk_width = w // self.vae_info.w_factor
            base_size, ratio_idx = self.vae_reso_group.get_base_size_and_ratio_index(w, h)
            tensor.i = ImageInfo(
                image_type=image_type,
                image_width=w, image_height=h, token_width=tk_width, token_height=tk_height,
                base_size=base_size, ratio_index=ratio_idx,
                ori_image_width=ori_image_width,
                ori_image_height=ori_image_height,
            )
            tensor.section_type = "cond_vae_image"
        elif image_type == "siglip2":
            # NaFlex-style inputs: token grid comes from `spatial_shapes`.
            spatial_shapes = kwargs["spatial_shapes"]
            pixel_attention_mask = kwargs["pixel_attention_mask"]
            tensor.i = ImageInfo(
                image_type=image_type,
                image_width=spatial_shapes[1].item() * self.vit_info.w_factor,
                image_height=spatial_shapes[0].item() * self.vit_info.h_factor,
                token_width=spatial_shapes[1].item(),
                token_height=spatial_shapes[0].item(),
                image_token_length=self.vit_info.max_token_length,
                ori_image_width=ori_image_width,
                ori_image_height=ori_image_height,
            )
            tensor.section_type = "cond_vit_image"
            tensor.vision_encoder_kwargs = {
                "spatial_shapes": spatial_shapes,
                "pixel_attention_mask": pixel_attention_mask,
            }
        elif image_type == "anyres":
            token_width = kwargs["resized_image_width"] // self.vit_info.w_factor
            token_height = kwargs["resized_image_height"] // self.vit_info.h_factor
            tensor.i = ImageInfo(
                image_type=image_type,
                image_width=kwargs["resized_image_width"],
                image_height=kwargs["resized_image_height"],
                token_width=token_width,
                token_height=token_height,
                # +1 per row and +2: presumably row-separator and BOI/EOI
                # tokens — TODO confirm against the tokenizer.
                image_token_length=token_height * (token_width + 1) + 2,
            )
            tensor.section_type = "cond_vit_image"
        else:
            raise ValueError(f"Unknown image type: {image_type}")
        return tensor

    def vae_process_image(self, image, target_size, random_crop: bool | str = False) -> ImageTensor:
        """Resize/crop a PIL image to `target_size` and wrap it for the VAE branch.

        `random_crop` may be a bool (True -> "random", False -> "center") or a
        crop-type string passed straight to `resize_and_crop`.
        """
        origin_size = image.size
        crop_type = random_crop if isinstance(random_crop, str) else ("random" if random_crop else "center")
        resized_image = resize_and_crop(image, target_size, crop_type=crop_type)
        return self.as_image_tensor(resized_image, image_type=self.vae_info.image_type, origin_size=origin_size)

    def vit_process_image(self, image) -> ImageTensor:
        """Run the ViT processor on a PIL image and wrap the result for the ViT branch."""
        origin_size = image.size
        inputs = self.vit_info.processor(image)
        image = inputs["pixel_values"].squeeze(0)

        # Forward every extra processor output (e.g. spatial_shapes,
        # pixel_attention_mask), dropping the batch dim on tensors.
        remain_keys = set(inputs.keys()) - {"pixel_values"}
        remain_kwargs = {}
        for key in remain_keys:
            if isinstance(inputs[key], torch.Tensor):
                remain_kwargs[key] = inputs[key].squeeze(0)
            else:
                remain_kwargs[key] = inputs[key]

        return self.as_image_tensor(image, image_type=self.vit_info.image_type, origin_size=origin_size, **remain_kwargs)

    def get_image_with_size(
        self,
        src: InputImage,
        random_crop: bool | str = False,
        return_type: str = "vae",
    ) -> tuple[ImageTensor | CondImage, bool]:
        """ For various image generation tasks, dynamic image sizes """
        image = load_image(src)
        # NOTE(review): `image_flag` is hard-coded, so `img_success` is always
        # True — looks like a leftover from a grayscale-detection path.
        image_flag = "normal"
        img_success = image_flag != "gray"
        origin_size = image.size

        if "vae" in return_type:
            target_size = self.vae_reso_group.get_target_size(*origin_size)
            vae_image_tensor = self.vae_process_image(image, target_size, random_crop=random_crop)
        else:
            vae_image_tensor = None

        if "vit" in return_type:
            vit_image_tensor = self.vit_process_image(image)
        else:
            vit_image_tensor = None

        if return_type == "vae":
            image_tensor = vae_image_tensor
        elif return_type == "vit":
            image_tensor = vit_image_tensor
        elif return_type == "vae_vit":
            image_tensor = CondImage(image_type=return_type, vae_image=vae_image_tensor, vit_image=vit_image_tensor)
        else:
            raise ValueError(f"Unknown return_type: {return_type}")

        return image_tensor, img_success

    def build_cond_images(
        self,
        image_list: Optional[list[InputImage]] = None,
        message_list: Optional[list[dict[str, Any]]] = None,
        infer_align_image_size: bool = False,
    ) -> Optional[list[CondImage]]:
        """Build condition images from an explicit list or chat-style messages.

        Exactly one of `image_list` / `message_list` may be provided; with
        neither, returns None (matching the Optional return annotation).

        Raises:
            ValueError: If both `image_list` and `message_list` are provided.
        """
        if image_list is not None and message_list is not None:
            raise ValueError("`image_list` and `message_list` cannot be provided at the same time.")
        if message_list is not None:
            # Collect image sources from message contents of type "image".
            image_list = []
            for message in message_list:
                visuals = [
                    content
                    for content in message["content"]
                    if isinstance(content, dict) and content["type"] in ["image"]
                ]
                image_list.extend([
                    vision_info[key]
                    for vision_info in visuals
                    for key in ["image", "url", "path", "base64"]
                    if key in vision_info and vision_info["type"] == "image"
                ])
        # Robustness fix: previously iterating `None` raised TypeError when
        # neither argument was given; return None per the annotation instead.
        if image_list is None:
            return None

        # When aligning output size to the input, stretch instead of cropping
        # so no conditioning content is lost.
        if infer_align_image_size:
            random_crop = "resize"
        else:
            random_crop = "center"

        return [
            self.get_image_with_size(src, return_type=self.cond_image_type, random_crop=random_crop)[0]
            for src in image_list
        ]

    def prepare_full_attn_slices(self, output, batch_idx=None, with_gen=True):
        """ Determine full attention image slices according to strategies. """
        if self.cond_image_type == "vae":
            cond_choices = dict(
                causal=[],
                full=output.vae_image_slices[batch_idx] if batch_idx is not None else output.vae_image_slices
            )

        elif self.cond_image_type == "vit":
            cond_choices = dict(
                causal=[],
                full=output.vit_image_slices[batch_idx] if batch_idx is not None else output.vit_image_slices
            )

        elif self.cond_image_type == "vae_vit":
            cond_choices = {
                "causal": [],
                "full": (
                    output.vae_image_slices[batch_idx] + output.vit_image_slices[batch_idx]
                    if batch_idx is not None
                    else output.vae_image_slices + output.vit_image_slices
                ),
                "joint_full": (
                    output.joint_image_slices[batch_idx]
                    if batch_idx is not None
                    else output.joint_image_slices
                ),
                # "full_causal": full attention on VAE tokens, causal on ViT.
                "full_causal": (
                    output.vae_image_slices[batch_idx]
                    if batch_idx is not None
                    else output.vae_image_slices
                ),
            }

        else:
            raise ValueError(f"Unknown cond_image_type: {self.cond_image_type}")
        slices = cond_choices[self.cond_token_attn_type]

        # Generated-image tokens always attend with full attention.
        if with_gen:
            gen_image_slices = (
                output.gen_image_slices[batch_idx]
                if batch_idx is not None
                else output.gen_image_slices
            )
            slices = slices + gen_image_slices
        return slices

    def build_img_ratio_slice_logits_processor(self, tokenizer):
        """Lazily build the logits processor that restricts sampling to ratio tokens."""
        if self.img_ratio_slice_logits_processor is None:
            self.img_ratio_slice_logits_processor = LogitsProcessorList()
            self.img_ratio_slice_logits_processor.append(
                SliceVocabLogitsProcessor(
                    vocab_start=tokenizer.start_ratio_token_id,
                    # end is exclusive, hence +1 to include the last ratio token.
                    vocab_end=tokenizer.end_ratio_token_id + 1,
                    other_slices=getattr(tokenizer, "ratio_token_other_slices", []),
                )
            )

    def postprocess_outputs(self, outputs: list[Image.Image], batch_cond_images, infer_align_image_size: bool = False):
        """Optionally resize outputs to match the aspect ratio of their condition images.

        When `infer_align_image_size` is set and an output shares a ratio bucket
        with a condition image whose true aspect ratio differs noticeably from
        the bucket's, the output is rescaled (area kept near base_size**2) to
        the condition image's original aspect ratio. Mutates and returns `outputs`.
        """
        if infer_align_image_size:
            target_area = self.vae_reso_group.base_size ** 2

            for batch_index, (output_image, cond_images) in enumerate(zip(outputs, batch_cond_images)):
                output_image_ratio_index = self.vae_reso_group.get_base_size_and_ratio_index(width=output_image.width, height=output_image.height)[1]
                # Gather ratio bucket and original size for each condition image.
                cond_images_ratio_index_list = []
                cond_images_ori_width_list = []
                cond_images_ori_height_list = []
                for cond_image in cond_images:
                    if isinstance(cond_image, ImageTensor):
                        cond_images_ratio_index_list.append(cond_image.i.ratio_index)
                        cond_images_ori_width_list.append(cond_image.i.ori_image_width)
                        cond_images_ori_height_list.append(cond_image.i.ori_image_height)
                    else:
                        # CondImage wrapper: read metadata from its VAE branch.
                        cond_images_ratio_index_list.append(cond_image.vae_image.i.ratio_index)
                        cond_images_ori_width_list.append(cond_image.vae_image.i.ori_image_width)
                        cond_images_ori_height_list.append(cond_image.vae_image.i.ori_image_height)

                if len(cond_images) == 0:
                    continue
                elif len(cond_images) == 1:
                    if output_image_ratio_index == cond_images_ratio_index_list[0]:
                        # Resize only when the true ratio differs from the bucket's.
                        if abs(cond_images_ori_height_list[0] / cond_images_ori_width_list[0] - self.vae_reso_group[output_image_ratio_index].ratio) >= 0.01:
                            scale = math.sqrt(target_area / (cond_images_ori_width_list[0] * cond_images_ori_height_list[0]))
                            new_w = round(cond_images_ori_width_list[0] * scale)
                            new_h = round(cond_images_ori_height_list[0] * scale)
                            outputs[batch_index] = output_image.resize((new_w, new_h), resample=Image.Resampling.LANCZOS)
                else:
                    # Multiple condition images: align to the first one in the
                    # same ratio bucket as the output.
                    for cond_image_ratio_index, cond_image_ori_width, cond_image_ori_height in zip(cond_images_ratio_index_list, cond_images_ori_width_list, cond_images_ori_height_list):
                        if output_image_ratio_index == cond_image_ratio_index:
                            if abs(cond_image_ori_height / cond_image_ori_width - self.vae_reso_group[output_image_ratio_index].ratio) >= 0.01:
                                scale = math.sqrt(target_area / (cond_image_ori_width * cond_image_ori_height))
                                new_w = round(cond_image_ori_width * scale)
                                new_h = round(cond_image_ori_height * scale)
                                outputs[batch_index] = output_image.resize((new_w, new_h), resample=Image.Resampling.LANCZOS)
                            break

        return outputs
| |
|
# Public API of this module.
__all__ = [
    "HunyuanImage3ImageProcessor"
]
| |
|