File size: 1,973 Bytes
7ebf906 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
from transformers import PretrainedConfig
from typing import Literal, Optional
from lit_gpt.config import Config
class DiffusionLlamaConfig(Config, PretrainedConfig):
    """Configuration for the diffusion-LLaMA model.

    Bridges lit_gpt's ``Config`` (architecture hyperparameters) with
    Hugging Face's ``PretrainedConfig`` (serialization / hub integration)
    so the model can be saved and loaded with ``transformers`` tooling.

    Architecture kwargs are forwarded to ``Config.__init__``; any extra
    ``**kwargs`` (e.g. ``torch_dtype``, ``transformers_version``) are
    forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "diff_llama_v2"
    # NOTE(fix): these were previously written with trailing commas
    # (`eos_token_id = 2,`), which made them one-element tuples instead
    # of ints. Hugging Face code expects plain integer token ids.
    eos_token_id = 2
    pad_token_id = 0
    # Extra token appended past the base vocab, used as the diffusion mask.
    mask_token_id = 32000

    def __init__(
        self,
        block_size: int = 4096,
        vocab_size: int = 50254,
        padding_multiple: int = 512,
        padded_vocab_size: Optional[int] = None,
        n_layer: int = 16,
        n_head: int = 32,
        n_embd: int = 4096,
        rotary_percentage: float = 0.25,
        parallel_residual: bool = True,
        bias: bool = True,
        n_query_groups: Optional[int] = None,
        shared_attention_norm: bool = False,
        _norm_class: Literal["LayerNorm", "RMSNorm", "FusedRMSNorm"] = "LayerNorm",
        norm_eps: float = 1e-5,
        _mlp_class: Literal["GptNeoxMLP", "LLaMAMLP"] = "GptNeoxMLP",
        intermediate_size: Optional[int] = None,
        condense_ratio: int = 1,
        **kwargs,
    ):
        # Initialize each base explicitly (not via super()) because the two
        # bases take disjoint argument sets; ``Config`` gets the architecture
        # parameters, ``PretrainedConfig`` gets everything else.
        Config.__init__(
            self,
            block_size=block_size,
            vocab_size=vocab_size,
            padding_multiple=padding_multiple,
            padded_vocab_size=padded_vocab_size,
            n_layer=n_layer,
            n_head=n_head,
            n_embd=n_embd,
            rotary_percentage=rotary_percentage,
            parallel_residual=parallel_residual,
            bias=bias,
            n_query_groups=n_query_groups,
            shared_attention_norm=shared_attention_norm,
            _norm_class=_norm_class,
            norm_eps=norm_eps,
            _mlp_class=_mlp_class,
            intermediate_size=intermediate_size,
            condense_ratio=condense_ratio,
        )
        PretrainedConfig.__init__(self, **kwargs)
|