| import os |
| import logging |
| from types import SimpleNamespace |
| from typing import Optional, Union |
|
|
| import accelerate |
| from accelerate import Accelerator, init_empty_weights |
| import torch |
| from safetensors.torch import load_file |
| from transformers import ( |
| LlamaTokenizerFast, |
| LlamaConfig, |
| LlamaModel, |
| CLIPTokenizer, |
| CLIPTextModel, |
| CLIPConfig, |
| SiglipImageProcessor, |
| SiglipVisionModel, |
| SiglipVisionConfig, |
| ) |
|
|
| from utils.safetensors_utils import load_split_weights |
| from hunyuan_model.vae import load_vae as hunyuan_load_vae |
|
|
| import logging |
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# NOTE(review): this runs at import time and sets up root logging at INFO
# (stdlib behaviour: it is a no-op when the root logger already has handlers).
logging.basicConfig(level=logging.INFO)
|
|
|
|
def load_vae(
    vae_path: str, vae_chunk_size: Optional[int], vae_spatial_tile_sample_min_size: Optional[int], device: Union[str, torch.device]
):
    """Load the VAE in float16 and configure chunking and spatial tiling.

    Args:
        vae_path: Either a weights file, or a directory. For a directory the
            weights are read from
            ``<vae_path>/vae/diffusion_pytorch_model.safetensors``.
        vae_chunk_size: When not ``None``, forwarded to
            ``vae.set_chunk_size_for_causal_conv_3d``.
        vae_spatial_tile_sample_min_size: When not ``None``, stored as
            ``vae.tile_sample_min_size``; ``vae.tile_latent_min_size`` is set
            to this value floor-divided by 8.
        device: Passed through to the underlying VAE loader.

    Returns:
        The loaded VAE, in eval mode, with spatial tiling enabled.
    """
    if os.path.isdir(vae_path):
        weights_file = os.path.join(vae_path, "vae", "diffusion_pytorch_model.safetensors")
    else:
        weights_file = vae_path

    # The loader returns four values; only the model itself is used here.
    vae, _, _, _ = hunyuan_load_vae(vae_dtype=torch.float16, device=device, vae_path=weights_file)
    vae.eval()

    if vae_chunk_size is not None:
        vae.set_chunk_size_for_causal_conv_3d(vae_chunk_size)
        logger.info(f"Set chunk_size to {vae_chunk_size} for CausalConv3d")

    # Spatial tiling is switched on unconditionally; an explicit minimum size
    # additionally sets the sample/latent tile sizes.
    vae.enable_spatial_tiling(True)
    if vae_spatial_tile_sample_min_size is not None:
        vae.tile_sample_min_size = vae_spatial_tile_sample_min_size
        vae.tile_latent_min_size = vae_spatial_tile_sample_min_size // 8
        logger.info(f"Enabled spatial tiling with min size {vae_spatial_tile_sample_min_size}")

    return vae
|
|
|
|
| |
|
|
| |
|
|
# Keyword arguments for LlamaConfig. load_text_encoder1 unpacks this dict to
# build the model when the text encoder is given as a weights file; a
# directory is loaded through from_pretrained instead and does not use it.
LLAMA_CONFIG = {
    "architectures": ["LlamaModel"],
    "attention_bias": False,
    "attention_dropout": 0.0,
    "bos_token_id": 128000,
    "eos_token_id": 128001,
    "head_dim": 128,
    "hidden_act": "silu",
    "hidden_size": 4096,
    "initializer_range": 0.02,
    "intermediate_size": 14336,
    "max_position_embeddings": 8192,
    "mlp_bias": False,
    "model_type": "llama",
    "num_attention_heads": 32,
    "num_hidden_layers": 32,
    "num_key_value_heads": 8,
    "pretraining_tp": 1,
    "rms_norm_eps": 1e-05,
    "rope_scaling": None,
    "rope_theta": 500000.0,
    "tie_word_embeddings": False,
    "torch_dtype": "float16",
    "transformers_version": "4.46.3",
    "use_cache": True,
    "vocab_size": 128320,
}
|
|
# Keyword arguments used by load_text_encoder2 to build the CLIP text model
# when the encoder is given as a single weights file; a directory is loaded
# through from_pretrained instead and does not use this dict.
# NOTE(review): "model_type" is "clip_text_model", yet the dict is unpacked
# into CLIPConfig rather than CLIPTextConfig - confirm this is intended.
CLIP_CONFIG = {
    "architectures": ["CLIPTextModel"],
    "attention_dropout": 0.0,
    "bos_token_id": 0,
    "dropout": 0.0,
    "eos_token_id": 2,
    "hidden_act": "quick_gelu",
    "hidden_size": 768,
    "initializer_factor": 1.0,
    "initializer_range": 0.02,
    "intermediate_size": 3072,
    "layer_norm_eps": 1e-05,
    "max_position_embeddings": 77,
    "model_type": "clip_text_model",
    "num_attention_heads": 12,
    "num_hidden_layers": 12,
    "pad_token_id": 1,
    "projection_dim": 768,
    "torch_dtype": "float16",
    "transformers_version": "4.48.0.dev0",
    "vocab_size": 49408,
}
|
|
|
|
def _normalize_llama_state_dict(state_dict: dict) -> dict:
    """Rewrite a loaded checkpoint in place so its keys match a bare ``LlamaModel``.

    If the checkpoint stores its weights under a leading ``"model."`` prefix
    (detected via ``"model.embed_tokens.weight"``), that prefix is stripped
    from every key that has it. The ``"tokenizer"`` and ``"lm_head.weight"``
    entries are dropped when present.

    Returns:
        The same dict object, for convenience.
    """
    prefix = "model."
    if "model.embed_tokens.weight" in state_dict:
        # Iterate over a snapshot of the keys because the dict is mutated.
        for key in list(state_dict):
            if key.startswith(prefix):
                # Slice off the leading prefix only; str.replace would also
                # delete any later "model." occurring inside the key.
                state_dict[key[len(prefix):]] = state_dict.pop(key)

    state_dict.pop("tokenizer", None)
    state_dict.pop("lm_head.weight", None)
    return state_dict


def _prepare_llama_for_fp8(llama_model: LlamaModel, target_dtype) -> None:
    """Patch a model that was cast to fp8 so selected modules avoid fp8 math.

    ``Embedding`` modules are cast to ``target_dtype``; ``LlamaRMSNorm``
    modules get a replacement ``forward`` that normalises in float32 and
    returns a tensor in the dtype of its input. Modules are matched by class
    name.
    """

    def make_rms_norm_forward(norm):
        def forward(hidden_states):
            input_dtype = hidden_states.dtype
            hidden_states = hidden_states.to(torch.float32)
            variance = hidden_states.pow(2).mean(-1, keepdim=True)
            hidden_states = hidden_states * torch.rsqrt(variance + norm.variance_epsilon)
            return norm.weight.to(input_dtype) * hidden_states.to(input_dtype)

        return forward

    for module in llama_model.modules():
        class_name = module.__class__.__name__
        if class_name == "Embedding":
            module.to(target_dtype)
        elif class_name == "LlamaRMSNorm":
            module.forward = make_rms_norm_forward(module)


def load_text_encoder1(
    args, fp8_llm: Optional[bool] = False, device: Optional[Union[str, torch.device]] = None
) -> tuple[LlamaTokenizerFast, LlamaModel]:
    """Load the Llama tokenizer and text encoder (text encoder 1).

    Args:
        args: Object exposing ``text_encoder1``. A directory is loaded with
            ``from_pretrained`` from its ``text_encoder`` subfolder. Anything
            else is treated as a weights file: the model is created from
            ``LLAMA_CONFIG`` with empty weights and the loaded tensors are
            assigned to it.
        fp8_llm: When truthy, the model is moved to ``device`` and cast to
            ``torch.float8_e4m3fn``, then patched so embeddings use the
            model's original dtype and RMSNorm runs in float32.
        device: Target device for the model.

    Returns:
        ``(tokenizer, text_encoder)`` with the encoder in eval mode.
    """
    logger.info("Loading text encoder 1 tokenizer")
    tokenizer1 = LlamaTokenizerFast.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder="tokenizer")

    logger.info(f"Loading text encoder 1 from {args.text_encoder1}")
    if os.path.isdir(args.text_encoder1):
        text_encoder1 = LlamaModel.from_pretrained(args.text_encoder1, subfolder="text_encoder", torch_dtype=torch.float16)
    else:
        config = LlamaConfig(**LLAMA_CONFIG)
        with init_empty_weights():
            text_encoder1 = LlamaModel._from_config(config, torch_dtype=torch.float16)

        state_dict = _normalize_llama_state_dict(load_split_weights(args.text_encoder1))
        # assign=True installs the loaded tensors themselves instead of
        # copying into the placeholder parameters created above.
        text_encoder1.load_state_dict(state_dict, strict=True, assign=True)

    if fp8_llm:
        # Remember the dtype before the fp8 cast so embeddings can be restored to it.
        org_dtype = text_encoder1.dtype
        logger.info(f"Moving and casting text encoder to {device} and torch.float8_e4m3fn")
        text_encoder1.to(device=device, dtype=torch.float8_e4m3fn)
        _prepare_llama_for_fp8(text_encoder1, org_dtype)
    else:
        text_encoder1.to(device)

    text_encoder1.eval()
    return tokenizer1, text_encoder1
|
|
|
|
def load_text_encoder2(args) -> tuple[CLIPTokenizer, CLIPTextModel]:
    """Load the CLIP tokenizer and text encoder (text encoder 2).

    ``args.text_encoder2`` may be a directory, in which case the model is
    loaded with ``from_pretrained`` from its ``text_encoder_2`` subfolder.
    Otherwise it is treated as a single weights file: the model is built from
    ``CLIP_CONFIG`` with empty weights and the file's tensors are assigned.

    Returns:
        ``(tokenizer, text_encoder)`` with the encoder in eval mode. This
        function does not call ``.to()`` on the model.
    """
    logger.info("Loading text encoder 2 tokenizer")
    clip_tokenizer = CLIPTokenizer.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder="tokenizer_2")

    encoder_path = args.text_encoder2
    logger.info(f"Loading text encoder 2 from {encoder_path}")

    if not os.path.isdir(encoder_path):
        # Single-file checkpoint: instantiate without allocating real weights,
        # then install the tensors read from the file.
        clip_config = CLIPConfig(**CLIP_CONFIG)
        with init_empty_weights():
            clip_model = CLIPTextModel._from_config(clip_config, torch_dtype=torch.float16)
        clip_model.load_state_dict(load_file(encoder_path), strict=True, assign=True)
    else:
        clip_model = CLIPTextModel.from_pretrained(encoder_path, subfolder="text_encoder_2", torch_dtype=torch.float16)

    clip_model.eval()
    return clip_tokenizer, clip_model
|
|
|
|
| |
|
|
| |
|
|
| |
# Keyword arguments for SiglipImageProcessor; load_image_encoders always
# builds the feature extractor from this dict.
FEATURE_EXTRACTOR_CONFIG = {
    "do_convert_rgb": None,
    "do_normalize": True,
    "do_rescale": True,
    "do_resize": True,
    "image_mean": [0.5, 0.5, 0.5],
    "image_processor_type": "SiglipImageProcessor",
    "image_std": [0.5, 0.5, 0.5],
    "processor_class": "SiglipProcessor",
    "resample": 3,  # presumably a PIL resampling filter id - TODO confirm
    "rescale_factor": 0.00392156862745098,  # equals 1 / 255
    "size": {"height": 384, "width": 384},
}
# Keyword arguments for SiglipVisionConfig. load_image_encoders unpacks this
# dict to build the vision model when the image encoder is given as a single
# weights file; a directory is loaded through from_pretrained instead.
# NOTE(review): "torch_dtype" here is "bfloat16", while the loader passes
# torch_dtype=torch.float16 explicitly when creating the model.
IMAGE_ENCODER_CONFIG = {
    # NOTE(review): absolute path into one machine's Hugging Face cache;
    # presumably metadata only - confirm nothing resolves it at runtime.
    "_name_or_path": "/home/lvmin/.cache/huggingface/hub/models--black-forest-labs--FLUX.1-Redux-dev/snapshots/1282f955f706b5240161278f2ef261d2a29ad649/image_encoder",
    "architectures": ["SiglipVisionModel"],
    "attention_dropout": 0.0,
    "hidden_act": "gelu_pytorch_tanh",
    "hidden_size": 1152,
    "image_size": 384,
    "intermediate_size": 4304,
    "layer_norm_eps": 1e-06,
    "model_type": "siglip_vision_model",
    "num_attention_heads": 16,
    "num_channels": 3,
    "num_hidden_layers": 27,
    "patch_size": 14,
    "torch_dtype": "bfloat16",
    "transformers_version": "4.46.2",
}
|
|
|
|
def load_image_encoders(args):
    """Build the SigLIP image processor and load the SigLIP vision model.

    The processor is always constructed from ``FEATURE_EXTRACTOR_CONFIG``.
    ``args.image_encoder`` may be a directory, in which case the model is
    loaded with ``from_pretrained`` from its ``image_encoder`` subfolder.
    Otherwise it is treated as a single weights file: the model is built from
    ``IMAGE_ENCODER_CONFIG`` with empty weights and the file's tensors are
    assigned.

    Returns:
        ``(feature_extractor, image_encoder)`` with the encoder in eval mode.
        This function does not call ``.to()`` on the model.
    """
    logger.info("Loading image encoder feature extractor")
    siglip_processor = SiglipImageProcessor(**FEATURE_EXTRACTOR_CONFIG)

    encoder_path = args.image_encoder
    logger.info(f"Loading image encoder from {encoder_path}")

    if not os.path.isdir(encoder_path):
        # Single-file checkpoint: instantiate without allocating real weights,
        # then install the tensors read from the file.
        vision_config = SiglipVisionConfig(**IMAGE_ENCODER_CONFIG)
        with init_empty_weights():
            vision_model = SiglipVisionModel._from_config(vision_config, torch_dtype=torch.float16)
        vision_model.load_state_dict(load_file(encoder_path), strict=True, assign=True)
    else:
        vision_model = SiglipVisionModel.from_pretrained(encoder_path, subfolder="image_encoder", torch_dtype=torch.float16)

    vision_model.eval()
    return siglip_processor, vision_model
|
|
|
|
| |
|
|