|
|
import torch |
|
|
from transformers import T5EncoderModel, T5Config |
|
|
from .sd_text_encoder import SDTextEncoder |
|
|
|
|
|
|
|
|
|
|
|
class FluxTextEncoder2(T5EncoderModel):
    """T5 encoder wrapper used as the second text encoder for Flux.

    Behaves like ``transformers.T5EncoderModel`` except that ``forward``
    returns only the final hidden-state tensor instead of a model-output
    object.
    """

    def __init__(self, config):
        """Build the encoder from ``config`` and switch it to eval mode.

        Args:
            config: Model configuration (e.g. a ``T5Config``), forwarded
                unchanged to ``T5EncoderModel.__init__``.
        """
        super().__init__(config)
        # Put the module in eval mode at construction so dropout layers are
        # inactive unless a caller explicitly calls ``.train()``.
        self.eval()

    def forward(self, input_ids, attention_mask=None):
        """Encode token ids and return the encoder's last hidden state.

        Args:
            input_ids: Token id tensor passed to ``T5EncoderModel.forward``.
                Presumably shaped (batch, seq_len) — TODO confirm with callers.
            attention_mask: Optional mask forwarded to
                ``T5EncoderModel.forward``. The default of ``None`` means no
                mask is applied, which matches the previous behavior.

        Returns:
            The ``last_hidden_state`` tensor produced by the T5 encoder.
        """
        # Request a ModelOutput explicitly. Otherwise the parent falls back to
        # ``config.use_return_dict``; if that is False (return_dict=False or
        # torchscript=True), a plain tuple comes back and
        # ``.last_hidden_state`` would raise AttributeError.
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        return outputs.last_hidden_state

    @staticmethod
    def state_dict_converter():
        """Return a new ``FluxTextEncoder2StateDictConverter`` instance."""
        return FluxTextEncoder2StateDictConverter()
|
|
|
|
|
|
|
|
|
|
|
class FluxTextEncoder2StateDictConverter:
    """Converts external checkpoints for ``FluxTextEncoder2``.

    Both supported source formats already use the key layout the model
    expects, so conversion hands back the very same mapping it was given.
    """

    def __init__(self):
        """Create a stateless converter; there is nothing to configure."""

    def from_diffusers(self, state_dict):
        """Return ``state_dict`` itself, unchanged (the identical object)."""
        return state_dict

    def from_civitai(self, state_dict):
        """Convert a civitai-format state dict.

        Delegates to ``from_diffusers`` (looked up on ``self``, so subclass
        overrides are honored).
        """
        converted = self.from_diffusers(state_dict)
        return converted
|
|
|