from typing import Optional

from transformers import PretrainedConfig


class LUARConfig(PretrainedConfig):
    """Configuration class for the LUAR model.

    Stores the hyperparameters that control the LUAR embedding head and its
    (optional) memory-efficient attention, on top of the standard
    ``PretrainedConfig`` machinery (which consumes ``**kwargs``).
    """

    # Registered model type identifier used by the transformers Auto* machinery.
    model_type = "LUAR"

    def __init__(
        self,
        embedding_size: int = 512,
        use_memory_efficient_attention: bool = False,
        q_bucket_size: int = 512,
        k_bucket_size: int = 1024,
        upstream_transformer_revision: Optional[str] = None,
        **kwargs,
    ):
        """Initialize the configuration.

        Args:
            embedding_size: Dimensionality of the output author embedding.
            use_memory_efficient_attention: Whether to use the bucketed,
                memory-efficient attention implementation.
            q_bucket_size: Query bucket size for memory-efficient attention
                (only relevant when ``use_memory_efficient_attention`` is True).
            k_bucket_size: Key bucket size for memory-efficient attention
                (only relevant when ``use_memory_efficient_attention`` is True).
            upstream_transformer_revision: Optional revision (branch, tag, or
                commit) of the upstream transformer backbone to load.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.
        """
        self.embedding_size = embedding_size
        self.use_memory_efficient_attention = use_memory_efficient_attention
        self.q_bucket_size = q_bucket_size
        self.k_bucket_size = k_bucket_size
        self.upstream_transformer_revision = upstream_transformer_revision
        # Set our attributes first, then let the base class process the
        # remaining standard config kwargs (original ordering preserved).
        super().__init__(**kwargs)