| | import torch |
| | from transformers import T5EncoderModel, T5Config |
| | from .sd_text_encoder import SDTextEncoder |
| |
|
| |
|
| |
|
class FluxTextEncoder2(T5EncoderModel):
    """T5 encoder wrapper whose ``forward`` returns only the last hidden state.

    Subclasses ``transformers.T5EncoderModel`` and is put into eval mode as
    soon as it is constructed.
    """

    def __init__(self, config):
        super().__init__(config)
        # Switch to inference mode immediately (e.g. dropout layers become no-ops).
        self.eval()

    def forward(self, input_ids, attention_mask=None):
        """Encode token ids and return the encoder's final hidden states.

        Args:
            input_ids: Token ids forwarded to ``T5EncoderModel.forward``.
            attention_mask: Optional mask forwarded to ``T5EncoderModel.forward``
                so padded positions can be excluded from attention. The default
                ``None`` is the parent's own default, so omitting it behaves
                exactly like the previous ``input_ids``-only call.

        Returns:
            ``last_hidden_state`` taken from the parent encoder's output.
        """
        outputs = super().forward(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state

    @staticmethod
    def state_dict_converter():
        """Return a new ``FluxTextEncoder2StateDictConverter`` for this model."""
        return FluxTextEncoder2StateDictConverter()
| |
|
| |
|
| |
|
class FluxTextEncoder2StateDictConverter:
    """State-dict converter for ``FluxTextEncoder2``.

    Both supported source formats are handled as an identity mapping: the
    incoming dict is handed back as-is (the same object, not a copy), with
    no keys renamed or dropped.
    """

    def __init__(self):
        # The converter holds no state; construction does nothing.
        pass

    def from_diffusers(self, state_dict):
        """Return ``state_dict`` itself, unmodified."""
        return state_dict

    def from_civitai(self, state_dict):
        """Convert a civitai-format dict by delegating to ``from_diffusers``."""
        return self.from_diffusers(state_dict)
| |
|