File size: 662 Bytes
feba2ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
"""
Model Config

Specifies the hyperparameters for the Pico model/model architecture.
"""

from dataclasses import dataclass
from typing import Optional

from ._constants import BATCH_SIZE, MAX_SEQ_LEN, VOCAB_SIZE


@dataclass
class ModelConfig:
    """Architecture hyperparameters for a Pico model.

    The defaults describe the ``pico_decoder`` model type. ``vocab_size``,
    ``batch_size`` and ``max_seq_len`` default to the shared project
    constants imported from ``._constants``.

    Raises:
        ValueError: (from ``__post_init__``) if the attention head counts are
            non-positive, if ``attention_n_heads`` does not evenly divide
            ``d_model``, or if ``attention_n_kv_heads`` (when not ``None``)
            does not evenly divide ``attention_n_heads``.
    """

    model_type: str = "pico_decoder"

    # Pico Decoder default hyperparameters

    # NOTE(review): presumably the hidden/embedding width and the number of
    # decoder layers -- confirm against the model implementation.
    d_model: int = 768
    n_layers: int = 12

    vocab_size: int = VOCAB_SIZE
    batch_size: int = BATCH_SIZE
    max_seq_len: int = MAX_SEQ_LEN

    # Query-head count and key/value-head count for attention.
    # NOTE(review): the meaning of ``None`` for ``attention_n_kv_heads`` is
    # decided by the model code (presumably "same as attention_n_heads") --
    # confirm before relying on it.
    attention_n_heads: int = 12
    attention_n_kv_heads: Optional[int] = 4

    activation_hidden_dim: int = 3072

    norm_eps: float = 1e-6

    # NOTE(review): presumably the base frequency of rotary position
    # embeddings -- confirm against the model implementation.
    position_emb_theta: float = 10000.0

    def __post_init__(self) -> None:
        """Reject attention head counts that cannot evenly split the model.

        Catching this at construction time surfaces a bad config at the point
        it is created, instead of as a shape error inside the model.
        """
        if self.attention_n_heads <= 0:
            raise ValueError(
                f"attention_n_heads must be positive, got {self.attention_n_heads}"
            )
        # Each head is assumed to get an equal slice of d_model.
        if self.d_model % self.attention_n_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by "
                f"attention_n_heads ({self.attention_n_heads})"
            )

        n_kv_heads = self.attention_n_kv_heads
        if n_kv_heads is None:
            # None is left for the model code to interpret; nothing to check.
            return
        if n_kv_heads <= 0:
            raise ValueError(
                f"attention_n_kv_heads must be positive or None, got {n_kv_heads}"
            )
        # Query heads are assumed to be shared evenly across key/value heads.
        if self.attention_n_heads % n_kv_heads != 0:
            raise ValueError(
                f"attention_n_heads ({self.attention_n_heads}) must be divisible "
                f"by attention_n_kv_heads ({n_kv_heads})"
            )