File size: 701 Bytes
c034681
da4b785
 
c034681
 
 
 
 
 
 
 
 
 
da4b785
c034681
 
 
 
 
 
c331155
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23

from typing import Optional

from transformers import PretrainedConfig

class LUARConfig(PretrainedConfig):
    """Configuration object for a LUAR model.

    Extends ``PretrainedConfig`` with five LUAR-specific settings, each stored
    on the instance under the same name as its constructor argument. Every
    other keyword argument is passed through unchanged to
    ``PretrainedConfig.__init__``.

    Args:
        embedding_size: Stored as ``self.embedding_size`` (default 512).
            Presumably the width of the embedding the model emits -- confirm
            against the model code.
        use_memory_efficient_attention: Stored as
            ``self.use_memory_efficient_attention`` (default ``False``).
            Presumably toggles a chunked attention implementation -- confirm
            against the model code.
        q_bucket_size: Stored as ``self.q_bucket_size`` (default 512).
            Presumably the query chunk size used when memory-efficient
            attention is enabled -- TODO confirm.
        k_bucket_size: Stored as ``self.k_bucket_size`` (default 1024).
            Presumably the key chunk size used when memory-efficient
            attention is enabled -- TODO confirm.
        upstream_transformer_revision: Stored as
            ``self.upstream_transformer_revision`` (default ``None``).
            Presumably pins the revision of an upstream transformer
            checkpoint -- verify against the loader that reads it.
        **kwargs: Forwarded as-is to ``PretrainedConfig.__init__``.
    """

    # Type identifier string for this configuration class.
    model_type = "LUAR"

    def __init__(
        self,
        embedding_size: int = 512,
        use_memory_efficient_attention=False,
        q_bucket_size=512,
        k_bucket_size=1024,
        upstream_transformer_revision: Optional[str] = None,
        **kwargs,
    ):
        # Gather the LUAR-specific settings in declaration order; dicts keep
        # insertion order, so the attributes are assigned in that same order.
        luar_settings = {
            "embedding_size": embedding_size,
            "use_memory_efficient_attention": use_memory_efficient_attention,
            "q_bucket_size": q_bucket_size,
            "k_bucket_size": k_bucket_size,
            "upstream_transformer_revision": upstream_transformer_revision,
        }
        for attr_name, attr_value in luar_settings.items():
            setattr(self, attr_name, attr_value)

        # The base initializer runs last, once the LUAR fields are in place,
        # and receives only the keyword arguments not consumed above.
        super().__init__(**kwargs)