import os
import logging
from types import SimpleNamespace
from typing import Optional, Union

import accelerate
from accelerate import Accelerator, init_empty_weights
import torch
from safetensors.torch import load_file
from transformers import (
    LlamaTokenizerFast,
    LlamaConfig,
    LlamaModel,
    CLIPTokenizer,
    CLIPTextModel,
    CLIPTextConfig,
    SiglipImageProcessor,
    SiglipVisionModel,
    SiglipVisionConfig,
)

from utils.safetensors_utils import load_split_weights
from hunyuan_model.vae import load_vae as hunyuan_load_vae

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def load_vae(
    vae_path: str,
    vae_chunk_size: Optional[int],
    vae_spatial_tile_sample_min_size: Optional[int],
    device: Union[str, torch.device],
):
    """Load the HunyuanVideo VAE in fp16 and configure chunking/tiling for memory-efficient decoding."""
    # A directory is assumed to be a Diffusers-style layout; point at the weights file inside it.
    if os.path.isdir(vae_path):
        vae_path = os.path.join(vae_path, "vae", "diffusion_pytorch_model.safetensors")

    vae_dtype = torch.float16
    vae, _, s_ratio, t_ratio = hunyuan_load_vae(vae_dtype=vae_dtype, device=device, vae_path=vae_path)
    vae.eval()

    # Chunked CausalConv3d bounds peak memory when processing long videos.
    if vae_chunk_size is not None:
        vae.set_chunk_size_for_causal_conv_3d(vae_chunk_size)
        logger.info(f"Set chunk_size to {vae_chunk_size} for CausalConv3d")

    # Spatial tiling is always enabled; an explicit min size also sets the tile sizes
    # (latent tiles are 1/8 of the sample size, matching the VAE's spatial downscale factor).
    if vae_spatial_tile_sample_min_size is not None:
        vae.enable_spatial_tiling(True)
        vae.tile_sample_min_size = vae_spatial_tile_sample_min_size
        vae.tile_latent_min_size = vae_spatial_tile_sample_min_size // 8
        logger.info(f"Enabled spatial tiling with min size {vae_spatial_tile_sample_min_size}")
    else:
        vae.enable_spatial_tiling(True)

    return vae
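
# Usage sketch (illustrative; the checkpoint path and sizes below are hypothetical):
#   vae = load_vae("models/hunyuan_video_vae_fp16.safetensors", vae_chunk_size=32,
#                  vae_spatial_tile_sample_min_size=256, device="cuda")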
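

# Hard-coded config for text encoder 1, the LLaMA-based text encoder shipped with
# HunyuanVideo (hunyuanvideo-community/HunyuanVideo). It is used to build an empty
# model skeleton when loading from a single safetensors file instead of a directory.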
LLAMA_CONFIG = {
    "architectures": ["LlamaModel"],
    "attention_bias": False,
    "attention_dropout": 0.0,
    "bos_token_id": 128000,
    "eos_token_id": 128001,
    "head_dim": 128,
    "hidden_act": "silu",
    "hidden_size": 4096,
    "initializer_range": 0.02,
    "intermediate_size": 14336,
    "max_position_embeddings": 8192,
    "mlp_bias": False,
    "model_type": "llama",
    "num_attention_heads": 32,
    "num_hidden_layers": 32,
    "num_key_value_heads": 8,
    "pretraining_tp": 1,
    "rms_norm_eps": 1e-05,
    "rope_scaling": None,
    "rope_theta": 500000.0,
    "tie_word_embeddings": False,
    "torch_dtype": "float16",
    "transformers_version": "4.46.3",
    "use_cache": True,
    "vocab_size": 128320,
}
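
# Hard-coded config for text encoder 2 (CLIP-L), used the same way as LLAMA_CONFIG above.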
CLIP_CONFIG = {
    "architectures": ["CLIPTextModel"],
    "attention_dropout": 0.0,
    "bos_token_id": 0,
    "dropout": 0.0,
    "eos_token_id": 2,
    "hidden_act": "quick_gelu",
    "hidden_size": 768,
    "initializer_factor": 1.0,
    "initializer_range": 0.02,
    "intermediate_size": 3072,
    "layer_norm_eps": 1e-05,
    "max_position_embeddings": 77,
    "model_type": "clip_text_model",
    "num_attention_heads": 12,
    "num_hidden_layers": 12,
    "pad_token_id": 1,
    "projection_dim": 768,
    "torch_dtype": "float16",
    "transformers_version": "4.48.0.dev0",
    "vocab_size": 49408,
}


def load_text_encoder1(
    args, fp8_llm: Optional[bool] = False, device: Optional[Union[str, torch.device]] = None
) -> tuple[LlamaTokenizerFast, LlamaModel]:
    """Load the tokenizer and LLaMA-based text encoder 1, optionally casting the model to fp8."""
    logger.info("Loading text encoder 1 tokenizer")
    tokenizer1 = LlamaTokenizerFast.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder="tokenizer")

    logger.info(f"Loading text encoder 1 from {args.text_encoder1}")
    if os.path.isdir(args.text_encoder1):
        # Diffusers-style directory layout
        text_encoder1 = LlamaModel.from_pretrained(args.text_encoder1, subfolder="text_encoder", torch_dtype=torch.float16)
    else:
        # Single safetensors file: build an empty model from the known config, then load weights into it.
        config = LlamaConfig(**LLAMA_CONFIG)
        with init_empty_weights():
            text_encoder1 = LlamaModel._from_config(config, torch_dtype=torch.float16)

        state_dict = load_split_weights(args.text_encoder1)

        # Checkpoints exported from a LlamaForCausalLM prefix the encoder weights with "model."
        # and may carry extra keys (tokenizer metadata, lm_head) that LlamaModel does not use.
        if "model.embed_tokens.weight" in state_dict:
            for key in list(state_dict.keys()):
                if key.startswith("model."):
                    state_dict[key[len("model.") :]] = state_dict.pop(key)
        if "tokenizer" in state_dict:
            state_dict.pop("tokenizer")
        if "lm_head.weight" in state_dict:
            state_dict.pop("lm_head.weight")

        text_encoder1.load_state_dict(state_dict, strict=True, assign=True)

    if fp8_llm:
        org_dtype = text_encoder1.dtype
        logger.info(f"Moving and casting text encoder to {device} and torch.float8_e4m3fn")
        text_encoder1.to(device=device, dtype=torch.float8_e4m3fn)

        # fp8 lacks the precision needed by embeddings and normalization statistics, so keep
        # embedding tables in the original dtype and run RMSNorm in fp32.
        def prepare_fp8(llama_model: LlamaModel, target_dtype):
            def forward_hook(module):
                # Reimplements LlamaRMSNorm.forward: normalize in fp32, then cast back
                # and scale by the learned weight (upcast from fp8 to the input dtype).
                def forward(hidden_states):
                    input_dtype = hidden_states.dtype
                    hidden_states = hidden_states.to(torch.float32)
                    variance = hidden_states.pow(2).mean(-1, keepdim=True)
                    hidden_states = hidden_states * torch.rsqrt(variance + module.variance_epsilon)
                    return module.weight.to(input_dtype) * hidden_states.to(input_dtype)

                return forward

            for module in llama_model.modules():
                if module.__class__.__name__ in ["Embedding"]:
                    # Keep embeddings in the original (higher-precision) dtype
                    module.to(target_dtype)
                if module.__class__.__name__ in ["LlamaRMSNorm"]:
                    # Patch RMSNorm to compute in fp32
                    module.forward = forward_hook(module)

        prepare_fp8(text_encoder1, org_dtype)
    else:
        text_encoder1.to(device)

    text_encoder1.eval()
    return tokenizer1, text_encoder1
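
# Usage sketch (illustrative; the args namespace and checkpoint path are hypothetical):
#   args = SimpleNamespace(text_encoder1="models/llava_llama3_fp16.safetensors")
#   tokenizer1, text_encoder1 = load_text_encoder1(args, fp8_llm=True, device="cuda")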


def load_text_encoder2(args) -> tuple[CLIPTokenizer, CLIPTextModel]:
    """Load the tokenizer and CLIP-L text encoder 2."""
    logger.info("Loading text encoder 2 tokenizer")
    tokenizer2 = CLIPTokenizer.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder="tokenizer_2")

    logger.info(f"Loading text encoder 2 from {args.text_encoder2}")
    if os.path.isdir(args.text_encoder2):
        # Diffusers-style directory layout
        text_encoder2 = CLIPTextModel.from_pretrained(args.text_encoder2, subfolder="text_encoder_2", torch_dtype=torch.float16)
    else:
        # Single safetensors file: build an empty model from the known config, then load weights into it.
        config = CLIPTextConfig(**CLIP_CONFIG)
        with init_empty_weights():
            text_encoder2 = CLIPTextModel._from_config(config, torch_dtype=torch.float16)

        state_dict = load_file(args.text_encoder2)

        text_encoder2.load_state_dict(state_dict, strict=True, assign=True)

    text_encoder2.eval()
    return tokenizer2, text_encoder2
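
# Usage sketch (illustrative; the checkpoint path is hypothetical):
#   tokenizer2, text_encoder2 = load_text_encoder2(SimpleNamespace(text_encoder2="models/clip_l.safetensors"))
#   ids = tokenizer2("a photo of a cat", return_tensors="pt").input_ids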
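

# Hard-coded configs for the SigLIP image processor and vision encoder used for image
# conditioning (matching the image encoder distributed with FLUX.1 Redux).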
FEATURE_EXTRACTOR_CONFIG = {
    "do_convert_rgb": None,
    "do_normalize": True,
    "do_rescale": True,
    "do_resize": True,
    "image_mean": [0.5, 0.5, 0.5],
    "image_processor_type": "SiglipImageProcessor",
    "image_std": [0.5, 0.5, 0.5],
    "processor_class": "SiglipProcessor",
    "resample": 3,
    "rescale_factor": 0.00392156862745098,
    "size": {"height": 384, "width": 384},
}
IMAGE_ENCODER_CONFIG = {
    "architectures": ["SiglipVisionModel"],
    "attention_dropout": 0.0,
    "hidden_act": "gelu_pytorch_tanh",
    "hidden_size": 1152,
    "image_size": 384,
    "intermediate_size": 4304,
    "layer_norm_eps": 1e-06,
    "model_type": "siglip_vision_model",
    "num_attention_heads": 16,
    "num_channels": 3,
    "num_hidden_layers": 27,
    "patch_size": 14,
    "torch_dtype": "bfloat16",
    "transformers_version": "4.46.2",
}


def load_image_encoders(args):
    """Load the SigLIP feature extractor and vision encoder used for image conditioning."""
    logger.info("Loading image encoder feature extractor")
    feature_extractor = SiglipImageProcessor(**FEATURE_EXTRACTOR_CONFIG)

    logger.info(f"Loading image encoder from {args.image_encoder}")
    if os.path.isdir(args.image_encoder):
        # Diffusers-style directory layout
        image_encoder = SiglipVisionModel.from_pretrained(args.image_encoder, subfolder="image_encoder", torch_dtype=torch.float16)
    else:
        # Single safetensors file: build an empty model from the known config, then load weights into it.
        config = SiglipVisionConfig(**IMAGE_ENCODER_CONFIG)
        with init_empty_weights():
            image_encoder = SiglipVisionModel._from_config(config, torch_dtype=torch.float16)

        state_dict = load_file(args.image_encoder)

        image_encoder.load_state_dict(state_dict, strict=True, assign=True)

    image_encoder.eval()
    return feature_extractor, image_encoder
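
# Usage sketch (illustrative; `image` is a hypothetical PIL.Image):
#   feature_extractor, image_encoder = load_image_encoders(args)
#   pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
#   embeds = image_encoder(pixel_values=pixel_values).last_hidden_state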
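

if __name__ == "__main__":
    # Minimal smoke test: a sketch only, with hypothetical placeholder paths for
    # locally downloaded single-file checkpoints.
    args = SimpleNamespace(
        text_encoder1="models/llava_llama3_fp16.safetensors",
        text_encoder2="models/clip_l.safetensors",
        image_encoder="models/sigclip_vision_384.safetensors",
    )
    tokenizer1, text_encoder1 = load_text_encoder1(args, fp8_llm=False, device="cpu")
    tokenizer2, text_encoder2 = load_text_encoder2(args)
    feature_extractor, image_encoder = load_image_encoders(args)
    vae = load_vae("models/hunyuan_video_vae_fp16.safetensors", vae_chunk_size=32, vae_spatial_tile_sample_min_size=256, device="cpu")
    logger.info(f"Loaded: {type(text_encoder1).__name__}, {type(text_encoder2).__name__}, {type(image_encoder).__name__}, {type(vae).__name__}")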