| from transformers import PretrainedConfig | |
class ImuruConfig(PretrainedConfig):
    """Configuration object holding the settings of an Imuru model.

    The constructor forwards any extra keyword arguments to
    ``PretrainedConfig.__init__`` and then stores each of its own
    arguments, unchanged, as an instance attribute of the same name.

    Args:
        t5_name_or_path: Name or path identifying the T5 component
            (presumably a Hub id or local directory -- how it is
            consumed is not visible in this file).
        vae_name_or_path: Name or path identifying the VAE component
            (same caveat as above).
        tokenizer_name_or_path: Name or path identifying the tokenizer
            (same caveat as above).
        slices_per_query: Stored as-is; its exact meaning is defined by
            the model code that reads this config.
        vae_channels: Stored as-is; presumably the channel count used
            by the VAE -- TODO confirm against the model code.
        style_enc: Stored as-is; selects a style-encoding mode. The set
            of accepted values is not visible in this file.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    # NOTE(review): the registered model type is "emuru" while the class is
    # named "Imuru" -- confirm this is intentional before changing either.
    model_type = "emuru"

    def __init__(
        self,
        t5_name_or_path: str = 'google-t5/t5-large',
        vae_name_or_path: str = 'blowing-up-groundhogs/emuru_vae',
        tokenizer_name_or_path: str = 'google/byt5-small',
        slices_per_query: int = 1,
        vae_channels: int = 1,
        style_enc: str = "mean",
        **kwargs,
    ) -> None:
        # Let the base class consume the generic options first, exactly as
        # the original ordering did, before the model-specific fields are set.
        super().__init__(**kwargs)

        # Names/paths of the sub-components.
        self.t5_name_or_path = t5_name_or_path
        self.vae_name_or_path = vae_name_or_path
        self.tokenizer_name_or_path = tokenizer_name_or_path

        # Model-specific hyperparameters.
        self.slices_per_query = slices_per_query
        self.vae_channels = vae_channels
        self.style_enc = style_enc