# Web-page header captured together with the file (not part of the program):
#   HunyuanImage-3.0-Instruct / image_processor.py
#   root
#   Fix LFS and upload model
#   036458a
# Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://github.com/Tencent-Hunyuan/HunyuanImage-3.0/blob/main/LICENSE
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from dataclasses import dataclass, field, asdict
from typing import Tuple, Optional, Callable, Union, Any
import random
import math
import torch
from PIL import Image
from torchvision import transforms
from transformers.image_processing_utils import BaseImageProcessor
from transformers.image_utils import load_image
from transformers.models.siglip2.image_processing_siglip2_fast import Siglip2ImageProcessorFast
from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList
from .tokenization_hunyuan_image_3 import ImageInfo, ImageTensor, CondImage, Resolution, ResolutionGroup
# An image input: either an in-memory PIL image or a string reference that is
# handed to `load_image` (see `HunyuanImage3ImageProcessor.get_image_with_size`).
InputImage = Union[Image.Image, str]
class SliceVocabLogitsProcessor(LogitsProcessor):
    """
    [`LogitsProcessor`] that performs vocab slicing, i.e. restricting the scores to one or more ranges of the
    vocabulary. This processor is often used in multimodal discrete LLMs, to ensure that we only sample within
    one modality.

    Note that the returned scores are *sliced*, not masked: the last dimension of the output is the total size
    of the selected ranges, so position 0 of the output corresponds to token id `vocab_start`.

    Args:
        vocab_start (`int`, *optional*): start of the slice; `None` means from 0.
        vocab_end (`int`, *optional*): end of the slice (exclusive); `None` means to the end of the vocabulary.
            When start and end are both `None` (and no `other_slices` are given), this processor does nothing.
        other_slices (keyword-only, *optional*): iterable of `(start, end)` pairs whose score columns are
            appended, in order, after the main slice.
    """

    def __init__(self, vocab_start: Optional[int] = None, vocab_end: Optional[int] = None, **kwargs):
        if vocab_start is not None and vocab_end is not None:
            assert vocab_start < vocab_end, f"Ensure vocab_start {vocab_start} < vocab_end {vocab_end}"
        self.vocab_start = vocab_start
        self.vocab_end = vocab_end
        # Treat an explicit `other_slices=None` like "no extra slices" instead of failing later in __call__.
        self.other_slices = kwargs.get("other_slices") or []

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        main_slice = scores[:, self.vocab_start: self.vocab_end]
        if not self.other_slices:
            return main_slice
        # Concatenate all requested ranges in one call rather than re-allocating once per extra slice.
        pieces = [main_slice]
        pieces.extend(scores[:, other_slice[0]: other_slice[1]] for other_slice in self.other_slices)
        return torch.cat(pieces, dim=-1)

    def __repr__(self):
        # Report the real class name (the previous text said "SliceVocabLogitsWarper").
        return (
            f"{type(self).__name__}(vocab_start={self.vocab_start}, vocab_end={self.vocab_end}, "
            f"other_slices={self.other_slices})"
        )
def resize_and_crop(image: Image.Image, target_size: Tuple[int, int], resample=Image.Resampling.LANCZOS, crop_type='center', crop_coords=None) -> Image.Image:
    """
    Resize `image` and crop it to exactly `target_size`.

    Args:
        image: source PIL image.
        target_size: `(width, height)` of the returned image.
        resample: resampling filter passed to `Image.resize`.
        crop_type: one of
            - `"resize"`: stretch directly to `target_size`, ignoring the aspect ratio;
            - `"center"`: aspect-preserving resize that covers the target, then a centered crop;
            - `"random"`: same resize, crop offset drawn with `random.randint`;
            - `"fixed"`: same resize, crop offset taken from `crop_coords`.
        crop_coords: `(left, top)` offset, required when `crop_type == "fixed"`.

    Returns:
        The resized and cropped image.

    Raises:
        ValueError: if `crop_type` is not one of the values above.
        AssertionError: if `crop_type == "fixed"` and `crop_coords` is None.
    """
    tw, th = target_size
    src_w, src_h = image.size
    target_ratio = th / tw
    source_ratio = src_h / src_w

    if crop_type == "resize":
        # Plain stretch; the crop box then covers the whole resized image.
        stretched = image.resize((tw, th), resample=resample)
        return stretched.crop((0, 0, tw, th))

    # Keep the aspect ratio: scale so that the resized image covers the target box.
    if source_ratio < target_ratio:
        scaled_h = th
        scaled_w = int(round(th / src_h * src_w))
    else:
        scaled_w = tw
        scaled_h = int(round(tw / src_w * src_h))

    # The crop offset is decided (and validated) before the image is resized.
    if crop_type == 'center':
        top = int(round((scaled_h - th) / 2.0))
        left = int(round((scaled_w - tw) / 2.0))
    elif crop_type == 'random':
        top = random.randint(0, scaled_h - th)
        left = random.randint(0, scaled_w - tw)
    elif crop_type == 'fixed':
        assert crop_coords is not None, 'crop_coords should be provided when crop_type is fixed.'
        left, top = crop_coords
    else:
        raise ValueError(f'crop_type must be center, random or fixed, but got {crop_type}')

    scaled = image.resize((scaled_w, scaled_h), resample=resample)
    return scaled.crop((left, top, left + tw, top + th))
@dataclass
class ResolutionGroupConfig:
    """Plain settings record whose fields are expanded as keyword arguments when building a `ResolutionGroup`."""

    base_size: Optional[int] = None
    step: Optional[int] = None
    align: int = 16

    def to_dict(self):
        """Return the fields as a `{name: value}` dictionary."""
        as_mapping = asdict(self)
        return as_mapping
@dataclass
class VAEInfo:
    """
    Description of the VAE image encoder.

    `h_factor` / `w_factor` are always recomputed in `__post_init__` as
    `down_*_factor * patch_size`; values passed to the constructor for them are overwritten.
    `image_type` falls back to `"vae"` when left as None.
    """

    encoder_type: str
    down_h_factor: int = -1
    down_w_factor: int = -1
    patch_size: int = 1
    h_factor: int = -1
    w_factor: int = -1
    image_type: Optional[str] = None

    def __post_init__(self):
        patch = self.patch_size
        # Total pixel-to-token reduction per axis = VAE downsampling times patch size.
        self.h_factor, self.w_factor = self.down_h_factor * patch, self.down_w_factor * patch
        self.image_type = "vae" if self.image_type is None else self.image_type
@dataclass
class ViTInfo:
    """
    Description of the ViT image encoder.

    When `image_type` is left as None it is derived from `encoder_type` as the text
    before the first "-" (e.g. "siglip2-so400m-patch16-naflex" -> "siglip2").
    """

    encoder_type: str
    h_factor: int = -1
    w_factor: int = -1
    max_token_length: int = 0   # pad to max_token_length
    processor: Callable = field(default_factory=BaseImageProcessor)
    image_type: Optional[str] = None

    def __post_init__(self):
        if self.image_type is not None:
            return
        prefix, _, _ = self.encoder_type.partition("-")
        self.image_type = prefix
class HunyuanImage3ImageProcessor(object):
    """
    Image pre- and post-processing helper.

    It builds `ImageInfo` records for images to be generated, converts conditional
    input images into VAE and/or ViT tensors annotated with an `ImageInfo`, selects
    which image slices get full attention, and optionally resizes generated images
    back to the aspect ratio of their conditional images.
    """

    def __init__(self, config):
        """
        Args:
            config: model config. The attributes read here are `image_base_size`, `vae_type`,
                `vae_downsample_factor`, `patch_size`, `vit_type`, `vit_processor`,
                `cond_token_attn_type` and `cond_image_type`.

        Raises:
            ValueError: if `config.vit_type` is not "siglip2-so400m-patch16-naflex".
        """
        self.config = config

        self.reso_group_config = ResolutionGroupConfig(base_size=config.image_base_size)
        self.vae_reso_group = ResolutionGroup(
            **self.reso_group_config.to_dict(),
            extra_resolutions=[
                Resolution("1024x768"),
                Resolution("1280x720"),
                Resolution("768x1024"),
                Resolution("720x1280"),
            ]
        )
        # Built lazily by `build_img_ratio_slice_logits_processor`.
        self.img_ratio_slice_logits_processor = None
        self.pil_image_to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),  # transform to [-1, 1]
        ])

        # NOTE(review): both the height and the width factor read index 0 of
        # `vae_downsample_factor`; confirm index 1 is not intended for the width.
        self.vae_info = VAEInfo(
            encoder_type=config.vae_type,
            down_h_factor=config.vae_downsample_factor[0], down_w_factor=config.vae_downsample_factor[0],
            patch_size=config.patch_size,
        )

        if config.vit_type == "siglip2-so400m-patch16-naflex":
            self.vit_processor = Siglip2ImageProcessorFast.from_dict(config.vit_processor)
        else:
            raise ValueError(f"Unsupported vit_type: {config.vit_type}")
        self.vit_info = ViTInfo(
            encoder_type=config.vit_type,
            h_factor=self.vit_processor.patch_size,
            w_factor=self.vit_processor.patch_size,
            max_token_length=self.vit_processor.max_num_patches,
            processor=self.vit_processor,
        )

        self.cond_token_attn_type = config.cond_token_attn_type
        self.cond_image_type = config.cond_image_type

    def build_gen_image_info(self, image_size, add_guidance_token=False, add_timestep_r_token=False) -> ImageInfo:
        """
        Build the `ImageInfo` describing an image that is about to be generated.

        Args:
            image_size: requested size, given as one of
                - a string "HxW" (height first),
                - a string "W:H" (a width:height ratio; it is swapped to height-first internally),
                - a string "<img_ratio_i>" selecting entry `i` of the VAE resolution group,
                - a list/tuple of two ints `(height, width)`.
            add_guidance_token: forwarded to `ImageInfo`.
            add_timestep_r_token: forwarded to `ImageInfo`.

        Returns:
            An `ImageInfo` with `image_type="gen_image"` whose width/height are the target size
            returned by the resolution group for the request.

        Raises:
            ValueError: if `image_size` has an unsupported type or string format.
            AssertionError: if the parsed size does not consist of exactly two integers.
        """
        # parse image size (HxW, H:W, or <img_ratio_i>)
        if isinstance(image_size, str):
            if image_size.startswith("<img_ratio_"):
                ratio_index = int(image_size.split("_")[-1].rstrip(">"))
                reso = self.vae_reso_group[ratio_index]
                image_size = reso.height, reso.width
            elif 'x' in image_size:
                image_size = [int(s) for s in image_size.split('x')]
            elif ':' in image_size:
                image_size = [int(s) for s in image_size.split(':')]
                assert len(image_size) == 2, f"`image_size` should be in the format of 'W:H', got {image_size}."
                # Note that ratio is width:height
                image_size = [image_size[1], image_size[0]]
            else:
                raise ValueError(
                    f"`image_size` should be in the format of 'HxW', 'W:H' or <img_ratio_i>, got {image_size}.")
            assert len(image_size) == 2, f"`image_size` should be in the format of 'HxW', got {image_size}."
        elif isinstance(image_size, (list, tuple)):
            assert len(image_size) == 2 and all(isinstance(s, int) for s in image_size), \
                f"`image_size` should be a tuple of two integers or a string in the format of 'HxW', got {image_size}."
        else:
            # The string format is height-first everywhere else in this method, so say 'HxW' here as well
            # (the message used to say 'WxH', which contradicted the parser).
            raise ValueError(f"`image_size` should be a tuple of two integers or a string in the format of 'HxW', "
                             f"got {image_size}.")

        # From here on `image_size` is (height, width); the resolution group takes (width, height).
        image_width, image_height = self.vae_reso_group.get_target_size(image_size[1], image_size[0])
        token_height = image_height // self.vae_info.h_factor
        token_width = image_width // self.vae_info.w_factor
        base_size, ratio_idx = self.vae_reso_group.get_base_size_and_ratio_index(image_size[1], image_size[0])
        image_info = ImageInfo(
            image_type="gen_image", image_width=image_width, image_height=image_height,
            token_width=token_width, token_height=token_height, base_size=base_size, ratio_index=ratio_idx,
            add_guidance_token=add_guidance_token, add_timestep_r_token=add_timestep_r_token,
        )
        return image_info

    def as_image_tensor(self, image, image_type, **kwargs) -> ImageTensor:
        """
        Turn `image` into a tensor annotated with an `ImageInfo` (attribute `.i`) and a `section_type`.

        Args:
            image: a PIL image (converted with `self.pil_image_to_tensor`) or an object that is used
                as the tensor as-is.
            image_type: "vae", "siglip2" or "anyres".
            **kwargs: `origin_size` (`(width, height)` of the source image) is always required.
                "siglip2" additionally needs `spatial_shapes` and `pixel_attention_mask`;
                "anyres" needs `resized_image_width` and `resized_image_height`.

        Returns:
            The annotated tensor.

        Raises:
            ValueError: for an unknown `image_type`.
            AssertionError: for "vae" when the tensor is not 3-D/4-D or its size is not divisible
                by the VAE factors.
            KeyError: if a required keyword argument is missing.
        """
        if isinstance(image, Image.Image):
            tensor = self.pil_image_to_tensor(image)
        else:
            tensor = image

        origin_size = kwargs["origin_size"]
        ori_image_width = origin_size[0]
        ori_image_height = origin_size[1]

        if image_type == "vae":
            assert tensor.ndim == 3 or tensor.ndim == 4
            h, w = tensor.shape[-2], tensor.shape[-1]
            assert (h % self.vae_info.h_factor == 0 and w % self.vae_info.w_factor == 0), \
                (f"Image size should be divisible by ({self.vae_info.h_factor}, {self.vae_info.w_factor}), "
                 f"but got ({h} x {w}).")
            tk_height = h // self.vae_info.h_factor
            tk_width = w // self.vae_info.w_factor
            base_size, ratio_idx = self.vae_reso_group.get_base_size_and_ratio_index(w, h)
            tensor.i = ImageInfo(
                image_type=image_type,
                image_width=w, image_height=h, token_width=tk_width, token_height=tk_height,
                base_size=base_size, ratio_index=ratio_idx,
                ori_image_width=ori_image_width,
                ori_image_height=ori_image_height,
            )
            tensor.section_type = "cond_vae_image"
        elif image_type == "siglip2":
            spatial_shapes = kwargs["spatial_shapes"]  # 2 (h, w)
            pixel_attention_mask = kwargs["pixel_attention_mask"]  # seq_len
            tensor.i = ImageInfo(
                image_type=image_type,
                image_width=spatial_shapes[1].item() * self.vit_info.w_factor,
                image_height=spatial_shapes[0].item() * self.vit_info.h_factor,
                token_width=spatial_shapes[1].item(),
                token_height=spatial_shapes[0].item(),
                image_token_length=self.vit_info.max_token_length,
                ori_image_width=ori_image_width,
                ori_image_height=ori_image_height,
            )
            tensor.section_type = "cond_vit_image"
            tensor.vision_encoder_kwargs = {
                "spatial_shapes": spatial_shapes,
                "pixel_attention_mask": pixel_attention_mask,
            }
        elif image_type == "anyres":
            token_width = kwargs["resized_image_width"] // self.vit_info.w_factor
            token_height = kwargs["resized_image_height"] // self.vit_info.h_factor
            tensor.i = ImageInfo(
                image_type=image_type,
                image_width=kwargs["resized_image_width"],
                image_height=kwargs["resized_image_height"],
                token_width=token_width,
                token_height=token_height,
                image_token_length=token_height * (token_width + 1) + 2,
            )
            tensor.section_type = "cond_vit_image"
        else:
            raise ValueError(f"Unknown image type: {image_type}")
        return tensor

    def vae_process_image(self, image, target_size, random_crop: bool | str = False) -> ImageTensor:
        """
        Resize/crop a PIL image to `target_size` and wrap it as a VAE image tensor.

        Args:
            image: source PIL image.
            target_size: `(width, height)` passed to `resize_and_crop`.
            random_crop: a `crop_type` string for `resize_and_crop`, or a bool where
                True means "random" and False means "center".
        """
        origin_size = image.size
        crop_type = random_crop if isinstance(random_crop, str) else ("random" if random_crop else "center")
        resized_image = resize_and_crop(image, target_size, crop_type=crop_type)
        return self.as_image_tensor(resized_image, image_type=self.vae_info.image_type, origin_size=origin_size)

    def vit_process_image(self, image) -> ImageTensor:
        """
        Run the ViT processor on a PIL image and wrap the result as a ViT image tensor.

        Every processor output other than `pixel_values` is forwarded to `as_image_tensor`
        as a keyword argument; tensor outputs have their leading dimension squeezed first.
        """
        origin_size = image.size
        inputs = self.vit_info.processor(image)
        image = inputs["pixel_values"].squeeze(0)  # (seq_len, dim)

        remain_kwargs = {}
        for key in set(inputs.keys()) - {"pixel_values"}:
            value = inputs[key]
            remain_kwargs[key] = value.squeeze(0) if isinstance(value, torch.Tensor) else value

        return self.as_image_tensor(image, image_type=self.vit_info.image_type, origin_size=origin_size, **remain_kwargs)

    def get_image_with_size(
            self,
            src: InputImage,
            random_crop: bool | str = False,
            return_type: str = "vae",
    ) -> tuple[ImageTensor | CondImage, bool]:
        """
        For various image generation tasks, dynamic image sizes.

        Args:
            src: image or image reference, loaded with `load_image`.
            random_crop: crop mode forwarded to `vae_process_image`.
            return_type: "vae", "vit" or "vae_vit".

        Returns:
            `(image, success)` where `image` is an `ImageTensor` ("vae" / "vit") or a
            `CondImage` holding both ("vae_vit"). `success` is currently always True.

        Raises:
            ValueError: for an unknown `return_type`.
        """
        image = load_image(src)
        # No image validation is performed here, so the success flag is constant; it is
        # kept in the return value for interface compatibility.
        img_success = True
        origin_size = image.size  # (w_ori, h_ori)

        if "vae" in return_type:
            target_size = self.vae_reso_group.get_target_size(*origin_size)
            vae_image_tensor = self.vae_process_image(image, target_size, random_crop=random_crop)
        else:
            vae_image_tensor = None

        if "vit" in return_type:
            vit_image_tensor = self.vit_process_image(image)
        else:
            vit_image_tensor = None

        if return_type == "vae":
            image_tensor = vae_image_tensor
        elif return_type == "vit":
            image_tensor = vit_image_tensor
        elif return_type == "vae_vit":
            image_tensor = CondImage(image_type=return_type, vae_image=vae_image_tensor, vit_image=vit_image_tensor)
        else:
            raise ValueError(f"Unknown return_type: {return_type}")

        return image_tensor, img_success

    def build_cond_images(
            self,
            image_list: Optional[list[InputImage]] = None,
            message_list: Optional[list[dict[str, Any]]] = None,
            infer_align_image_size: bool = False,
    ) -> Optional[list[CondImage]]:
        """
        Load and preprocess the conditional images of one sample.

        Args:
            image_list: images (or image references) to process.
            message_list: chat-style messages; images are collected from every content item
                that is a dict with `"type" == "image"`, reading its "image", "url", "path"
                and "base64" entries (each one that is present contributes an image).
            infer_align_image_size: when True the VAE image is stretched to the target size
                ("resize"); otherwise it is center-cropped.

        Returns:
            The processed images in input order, or None when neither `image_list` nor
            `message_list` is given.

        Raises:
            ValueError: if both `image_list` and `message_list` are given.
        """
        if image_list is not None and message_list is not None:
            raise ValueError("`image_list` and `message_list` cannot be provided at the same time.")

        if message_list is not None:
            image_list = []
            for message in message_list:
                # `.get` so that dict content items without a "type" key are skipped instead of raising.
                visuals = [
                    content
                    for content in message["content"]
                    if isinstance(content, dict) and content.get("type") == "image"
                ]
                image_list.extend(
                    vision_info[key]
                    for vision_info in visuals
                    for key in ["image", "url", "path", "base64"]
                    if key in vision_info
                )

        if image_list is None:
            # Nothing to condition on. Previously this fell through and failed with a
            # TypeError while iterating None.
            return None

        random_crop = "resize" if infer_align_image_size else "center"

        return [
            self.get_image_with_size(src, return_type=self.cond_image_type, random_crop=random_crop)[0]
            for src in image_list
        ]

    def prepare_full_attn_slices(self, output, batch_idx=None, with_gen=True):
        """
        Determine full attention image slices according to strategies.

        Args:
            output: object exposing `vae_image_slices`, `vit_image_slices`, `joint_image_slices`
                and `gen_image_slices` (only those needed for the configured `cond_image_type`
                are read).
            batch_idx: when given, each slice collection is indexed with it first.
            with_gen: when True the generated-image slices are appended to the result.

        Returns:
            The slices selected by `self.cond_token_attn_type`, optionally followed by the
            generated-image slices.

        Raises:
            ValueError: for an unknown `self.cond_image_type`.
            KeyError: if `self.cond_token_attn_type` is not available for that image type.
        """
        def pick(slices):
            # Per-sample view when a batch index is given, otherwise the whole collection.
            return slices[batch_idx] if batch_idx is not None else slices

        if self.cond_image_type == "vae":
            cond_choices = dict(
                causal=[],
                full=pick(output.vae_image_slices),
            )
        elif self.cond_image_type == "vit":
            cond_choices = dict(
                causal=[],
                full=pick(output.vit_image_slices),
            )
        elif self.cond_image_type == "vae_vit":
            cond_choices = {
                "causal": [],
                "full": (
                    output.vae_image_slices[batch_idx] + output.vit_image_slices[batch_idx]
                    if batch_idx is not None
                    else output.vae_image_slices + output.vit_image_slices
                ),
                "joint_full": pick(output.joint_image_slices),
                "full_causal": pick(output.vae_image_slices),
            }
        else:
            raise ValueError(f"Unknown cond_image_type: {self.cond_image_type}")
        slices = cond_choices[self.cond_token_attn_type]

        if with_gen:
            slices = slices + pick(output.gen_image_slices)
        return slices

    def build_img_ratio_slice_logits_processor(self, tokenizer):
        """
        Create (once) the logits-processor list that restricts sampling to image-ratio tokens.

        The result is stored on `self.img_ratio_slice_logits_processor`; later calls are no-ops.
        The slice covers `tokenizer.start_ratio_token_id` through `tokenizer.end_ratio_token_id`
        inclusive, plus `tokenizer.ratio_token_other_slices` when the tokenizer defines it.
        """
        if self.img_ratio_slice_logits_processor is None:
            self.img_ratio_slice_logits_processor = LogitsProcessorList()
            self.img_ratio_slice_logits_processor.append(
                SliceVocabLogitsProcessor(
                    vocab_start=tokenizer.start_ratio_token_id,
                    vocab_end=tokenizer.end_ratio_token_id + 1,
                    other_slices=getattr(tokenizer, "ratio_token_other_slices", []),
                )
            )

    def postprocess_outputs(self, outputs: list[Image.Image], batch_cond_images, infer_align_image_size: bool = False):
        """
        Optionally resize generated images to the original aspect ratio of a conditional image.

        For each output, the conditional images are scanned in order. When one falls into the
        same ratio bucket as the output and its original height/width ratio differs from the
        bucket ratio by at least 0.01, the output is resized (LANCZOS) to that original aspect
        ratio with an area of about `base_size ** 2`.

        Args:
            outputs: generated images; modified in place.
            batch_cond_images: per-output collections of conditional images (`ImageTensor`, or
                objects exposing the VAE tensor as `.vae_image`).
            infer_align_image_size: when False the outputs are returned untouched.

        Returns:
            `outputs` (the same list object).
        """
        if not infer_align_image_size:
            return outputs

        target_area = self.vae_reso_group.base_size ** 2

        for batch_index, (output_image, cond_images) in enumerate(zip(outputs, batch_cond_images)):
            output_image_ratio_index = self.vae_reso_group.get_base_size_and_ratio_index(
                width=output_image.width, height=output_image.height)[1]

            # An `ImageTensor` carries its info directly; a `CondImage` carries it on its VAE image.
            cond_infos = [
                cond_image.i if isinstance(cond_image, ImageTensor) else cond_image.vae_image.i
                for cond_image in cond_images
            ]

            # The single-image and multi-image cases share the same rule, so one loop covers
            # both (and an empty collection is a no-op).
            for info in cond_infos:
                if info.ratio_index != output_image_ratio_index:
                    continue
                ori_width, ori_height = info.ori_image_width, info.ori_image_height
                bucket_ratio = self.vae_reso_group[output_image_ratio_index].ratio
                if abs(ori_height / ori_width - bucket_ratio) >= 0.01:
                    scale = math.sqrt(target_area / (ori_width * ori_height))
                    new_w = round(ori_width * scale)
                    new_h = round(ori_height * scale)
                    outputs[batch_index] = output_image.resize((new_w, new_h), resample=Image.Resampling.LANCZOS)
                    # NOTE(review): stops at the first conditional image that triggers a resize.
                    # Confirm against the upstream file whether the scan should instead stop at
                    # the first image in the same ratio bucket, resized or not.
                    break

        return outputs
# Names exported by `from <module> import *`.
__all__ = [
"HunyuanImage3ImageProcessor"
]