| | from transformers import BertModel, BertConfig, T5EncoderModel, T5Config |
| | import torch |
| |
|
| |
|
| |
|
class HunyuanDiTCLIPTextEncoder(BertModel):
    """BERT-based text encoder used as the primary ("CLIP") text encoder in HunyuanDiT.

    A plain ``BertModel`` (no pooling layer) with a fixed configuration that
    matches the HunyuanDiT checkpoint (24 layers, hidden size 1024,
    vocab size 47020).
    """

    def __init__(self):
        # Fixed configuration matching the HunyuanDiT CLIP-text checkpoint;
        # weights are expected to be loaded afterwards via a state dict.
        config = BertConfig(
            _name_or_path = "",
            architectures = ["BertModel"],
            attention_probs_dropout_prob = 0.1,
            bos_token_id = 0,
            classifier_dropout = None,
            directionality = "bidi",
            eos_token_id = 2,
            hidden_act = "gelu",
            hidden_dropout_prob = 0.1,
            hidden_size = 1024,
            initializer_range = 0.02,
            intermediate_size = 4096,
            layer_norm_eps = 1e-12,
            max_position_embeddings = 512,
            model_type = "bert",
            num_attention_heads = 16,
            num_hidden_layers = 24,
            output_past = True,
            pad_token_id = 0,
            pooler_fc_size = 768,
            pooler_num_attention_heads = 12,
            pooler_num_fc_layers = 3,
            pooler_size_per_head = 128,
            pooler_type = "first_token_transform",
            position_embedding_type = "absolute",
            torch_dtype = "float32",
            transformers_version = "4.37.2",
            type_vocab_size = 2,
            use_cache = True,
            vocab_size = 47020
        )
        # Pooling layer is unused for prompt embeddings, so skip building it.
        super().__init__(config, add_pooling_layer=False)
        self.eval()

    def forward(self, input_ids, attention_mask=None, clip_skip=1):
        """Encode token ids into per-token prompt embeddings.

        Args:
            input_ids: integer tensor of shape (batch, seq_len).
            attention_mask: optional mask over ``input_ids``; when ``None``
                every position is attended to. (Previously this parameter had
                no default, making the ``None`` fallback below unreachable.)
            clip_skip: take the hidden states ``clip_skip`` layers from the
                end (1 = final layer output).

        Returns:
            Tensor of shape (batch, seq_len, hidden_size).
        """
        input_shape = input_ids.size()
        batch_size, seq_length = input_shape
        device = input_ids.device

        # Attend over all tokens when the caller supplies no mask.
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length), device=device)

        # Convert the 2-D mask to the broadcastable additive mask the
        # transformer layers expect.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=None,
            token_type_ids=None,
            inputs_embeds=None,
            past_key_values_length=0,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            past_key_values=None,
            use_cache=False,
            output_attentions=False,
            output_hidden_states=True,
            return_dict=True,
        )
        all_hidden_states = encoder_outputs.hidden_states
        prompt_emb = all_hidden_states[-clip_skip]
        if clip_skip > 1:
            # Rescale the earlier layer's activations to the final layer's
            # global mean/std so downstream modules see a consistent scale.
            mean, std = all_hidden_states[-1].mean(), all_hidden_states[-1].std()
            prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
        return prompt_emb

    @staticmethod
    def state_dict_converter():
        # Converter that adapts external checkpoint key names to this model.
        return HunyuanDiTCLIPTextEncoderStateDictConverter()
| |
|
| |
|
| |
|
class HunyuanDiTT5TextEncoder(T5EncoderModel):
    """mT5 encoder used as the secondary text encoder in HunyuanDiT.

    A ``T5EncoderModel`` with a fixed configuration matching the HunyuanDiT
    mT5 checkpoint (24 encoder layers, d_model 2048, multilingual vocab of
    250112 tokens).
    """

    def __init__(self):
        # Fixed configuration matching the HunyuanDiT mT5 checkpoint;
        # weights are expected to be loaded afterwards via a state dict.
        config = T5Config(
            _name_or_path = "../HunyuanDiT/t2i/mt5",
            architectures = ["MT5ForConditionalGeneration"],
            classifier_dropout = 0.0,
            d_ff = 5120,
            d_kv = 64,
            d_model = 2048,
            decoder_start_token_id = 0,
            dense_act_fn = "gelu_new",
            dropout_rate = 0.1,
            eos_token_id = 1,
            feed_forward_proj = "gated-gelu",
            initializer_factor = 1.0,
            is_encoder_decoder = True,
            is_gated_act = True,
            layer_norm_epsilon = 1e-06,
            model_type = "t5",
            num_decoder_layers = 24,
            num_heads = 32,
            num_layers = 24,
            output_past = True,
            pad_token_id = 0,
            relative_attention_max_distance = 128,
            relative_attention_num_buckets = 32,
            tie_word_embeddings = False,
            tokenizer_class = "T5Tokenizer",
            transformers_version = "4.37.2",
            use_cache = True,
            vocab_size = 250112
        )
        super().__init__(config)
        self.eval()

    def forward(self, input_ids, attention_mask=None, clip_skip=1):
        """Encode token ids into per-token prompt embeddings.

        Args:
            input_ids: integer tensor of shape (batch, seq_len).
            attention_mask: optional mask over ``input_ids``; ``None`` is
                forwarded to ``T5EncoderModel.forward``, which then attends
                to every position. Default added for consistency with the
                CLIP text encoder's optional-mask contract.
            clip_skip: take the hidden states ``clip_skip`` layers from the
                end (1 = final layer output).

        Returns:
            Tensor of shape (batch, seq_len, d_model).
        """
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        prompt_emb = outputs.hidden_states[-clip_skip]
        if clip_skip > 1:
            # Rescale the earlier layer's activations to the final layer's
            # global mean/std so downstream modules see a consistent scale.
            mean, std = outputs.hidden_states[-1].mean(), outputs.hidden_states[-1].std()
            prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
        return prompt_emb

    @staticmethod
    def state_dict_converter():
        # Converter that adapts external checkpoint key names to this model.
        return HunyuanDiTT5TextEncoderStateDictConverter()
| |
|
| |
|
| |
|
class HunyuanDiTCLIPTextEncoderStateDictConverter():
    """Renames checkpoint keys for :class:`HunyuanDiTCLIPTextEncoder`.

    Keeps only parameters under the ``bert.`` namespace and strips that
    prefix so the keys match a bare ``BertModel``.
    """

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        """Return a new dict with ``bert.``-prefixed entries, prefix removed."""
        prefix = "bert."
        converted = {}
        for key, value in state_dict.items():
            if key.startswith(prefix):
                converted[key[len(prefix):]] = value
        return converted

    def from_civitai(self, state_dict):
        """Civitai checkpoints use the same layout as diffusers ones."""
        return self.from_diffusers(state_dict)
| |
|
| |
|
class HunyuanDiTT5TextEncoderStateDictConverter():
    """Renames checkpoint keys for :class:`HunyuanDiTT5TextEncoder`.

    Keeps the ``encoder.*`` parameters plus the shared token embedding
    table (``shared.weight``); decoder weights are dropped.
    """

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        """Return a new dict with encoder weights and the shared embedding."""
        converted = {}
        for key, value in state_dict.items():
            if key.startswith("encoder."):
                converted[key] = value
        # The embedding table lives outside the encoder namespace but is
        # required by the encoder-only model.
        converted["shared.weight"] = state_dict["shared.weight"]
        return converted

    def from_civitai(self, state_dict):
        """Civitai checkpoints use the same layout as diffusers ones."""
        return self.from_diffusers(state_dict)
| |
|