| import dataclasses |
| from typing import Optional |
|
|
@dataclasses.dataclass
class LyraLLaMAParam:
    """Configuration parameters for the Lyra LLaMA model.

    Every field has a default, so ``LyraLLaMAParam()`` yields a complete
    configuration; override individual fields with keyword arguments.
    ``shared_contexts_ratio`` is range-checked on construction.
    """

    # Model shape (names suggest attention heads / head dim / FFN width /
    # layer count / vocabulary size -- confirm against the weight loader).
    num_heads: int = 40
    size_per_head: int = 128
    inter_size: int = 13824
    num_layers: int = 40
    vocab_size: int = 39424
    # Special token ids; typed Optional, so None is an accepted value.
    start_id: Optional[int] = 1
    end_id: Optional[int] = 2
    # Parallelism degrees (presumably tensor / pipeline model parallel).
    tensor_para_size: int = 1
    pipeline_para_size: int = 1
    remove_padding: bool = True
    # Must lie in the closed interval [0.0, 1.0]; enforced in __post_init__.
    shared_contexts_ratio: float = 1.0
    layernorm_eps: float = 1e-6
    weights_data_type: str = "fp16"
    rotary_embedding: int = 128
    use_gptj_residual: bool = False

    def __post_init__(self):
        """Validate field values after dataclass initialisation.

        Raises:
            ValueError: if ``shared_contexts_ratio`` is outside [0.0, 1.0].
                NaN also fails the chained comparison and is rejected.
        """
        if not 0.0 <= self.shared_contexts_ratio <= 1.0:
            # The message names the actual field so it can be grepped for.
            raise ValueError(
                'Got an invalid value of shared_contexts_ratio '
                f'{self.shared_contexts_ratio} - range: [0.0, 1.0]')

    def asdict(self):
        """Return all fields as a plain dict (via ``dataclasses.asdict``)."""
        return dataclasses.asdict(self)
|
|
|
|
# Absolute path of the compiled extension library inside the deployment image.
# The file name suggests an sm80 / CUDA 11 build -- confirm before relocating.
LIB_SO_PATH = '/app/LyraLLaMAPy/ftlib/libth_transformer_sm80_cu11.so'

# Shared default configuration, built purely from LyraLLaMAParam's defaults.
LYRA_LLAMA_PARAM = LyraLLaMAParam()