File size: 701 Bytes
c034681 da4b785 c034681 da4b785 c034681 c331155 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
from typing import Optional
from transformers import PretrainedConfig
class LUARConfig(PretrainedConfig):
    """Configuration class for the LUAR model.

    Stores the hyperparameters consumed by the LUAR model implementation;
    all remaining keyword arguments are forwarded to ``PretrainedConfig``.

    Args:
        embedding_size: Dimensionality of the output embedding.
        use_memory_efficient_attention: Whether to use the memory-efficient
            attention implementation.
        q_bucket_size: Query bucket size for memory-efficient attention.
        k_bucket_size: Key bucket size for memory-efficient attention.
        upstream_transformer_revision: Presumably the checkpoint revision
            (e.g. git tag/commit) of the upstream transformer to load;
            ``None`` means the default revision. TODO(review): confirm
            against the model code that consumes this field.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "LUAR"

    def __init__(
        self,
        embedding_size: int = 512,
        use_memory_efficient_attention: bool = False,
        q_bucket_size: int = 512,
        k_bucket_size: int = 1024,
        upstream_transformer_revision: Optional[str] = None,
        **kwargs,
    ):
        self.embedding_size = embedding_size
        self.use_memory_efficient_attention = use_memory_efficient_attention
        self.q_bucket_size = q_bucket_size
        self.k_bucket_size = k_bucket_size
        self.upstream_transformer_revision = upstream_transformer_revision
        # Called last (HF convention) so the explicit attributes above are
        # set before the base class consumes the remaining kwargs.
        super().__init__(**kwargs)
|