| import torch |
| from transformers import T5EncoderModel, T5Config |
| from .sd_text_encoder import SDTextEncoder |
|
|
|
|
|
|
class FluxTextEncoder2(T5EncoderModel):
    """T5 encoder wrapper whose forward pass returns only the last hidden state."""

    def __init__(self, config):
        """Build the underlying T5 encoder from ``config`` and call ``eval()`` on it."""
        super().__init__(config)
        # The module is put into evaluation mode as soon as it is constructed.
        self.eval()

    def forward(self, input_ids):
        """Run the parent encoder on ``input_ids`` and return its ``last_hidden_state``."""
        # Only ``input_ids`` is forwarded to the parent; every other argument
        # of the parent ``forward`` keeps its default value.
        return super().forward(input_ids=input_ids).last_hidden_state

    @staticmethod
    def state_dict_converter():
        """Return a new ``FluxTextEncoder2StateDictConverter`` instance."""
        return FluxTextEncoder2StateDictConverter()
|
|
|
|
|
|
class FluxTextEncoder2StateDictConverter:
    """Pass-through state-dict converter: both entry points return the input as-is."""

    def __init__(self):
        """Create the converter; it holds no state."""

    def from_diffusers(self, state_dict):
        """Return ``state_dict`` itself (same object, no copy, no key changes)."""
        return state_dict

    def from_civitai(self, state_dict):
        """Delegate to ``from_diffusers``, so the result is again ``state_dict`` itself."""
        converted = self.from_diffusers(state_dict)
        return converted
|
|