from typing import Optional

from transformers import PretrainedConfig


class VeronicaConfig(PretrainedConfig):
    model_type = "veronica"

    def __init__(
        self,
        vocab_size: int = 50257,
        n_layer: int = 24,
        n_head: int = 12,
        n_embd: int = 768,
        mlp_mult: float = 4.0,
        num_funcs: int = 3,
        router_dim: Optional[int] = None,
        dropout: float = 0.0,
        use_channel_attention: bool = False,
        max_position_embeddings: int = 4096,
        layer_norm_epsilon: float = 1e-5,
        gradient_checkpointing: bool = False,
        # router aux-loss weight (entropy regularizer)
        router_aux_weight: float = 0.02,
        # router temperature (softmax(logits / tau))
        router_tau: float = 1.0,
        # RoPE theta (base for frequency computation)
        rope_theta: float = 10000.0,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Base dimensions
        self.vocab_size = vocab_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd
        self.mlp_mult = mlp_mult
        self.num_funcs = num_funcs
        self.router_dim = router_dim
        self.dropout = dropout
        self.use_channel_attention = use_channel_attention
        self.max_position_embeddings = max_position_embeddings
        self.layer_norm_epsilon = layer_norm_epsilon
        self.gradient_checkpointing = gradient_checkpointing

        # HF standard field names
        self.num_hidden_layers = n_layer
        self.num_attention_heads = n_head
        self.hidden_size = n_embd

        # Router
        self.router_aux_weight = router_aux_weight
        self.router_tau = router_tau
        
        # RoPE
        self.rope_theta = rope_theta
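

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original config): how a config
# like this is typically instantiated, registered with AutoConfig, and
# round-tripped to disk. The registered name "veronica" mirrors `model_type`
# above; the save path is a hypothetical example.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoConfig

    # Register so AutoConfig.from_pretrained can resolve model_type="veronica".
    AutoConfig.register("veronica", VeronicaConfig)

    # Build a config, overriding a couple of defaults.
    config = VeronicaConfig(n_layer=12, router_tau=0.7)

    # Round-trip through config.json, as any PretrainedConfig supports.
    config.save_pretrained("./veronica-config")
    reloaded = AutoConfig.from_pretrained("./veronica-config")
    assert reloaded.router_tau == 0.7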