from typing import Optional

from transformers import PretrainedConfig


class PhiReasoningConfig(PretrainedConfig):
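    r"""
    Configuration class for a Phi-style causal language model extended with a reasoning
    module. It stores the standard Phi architecture hyperparameters (vocabulary size,
    hidden size, attention layout, rotary embeddings, dropout) alongside the
    reasoning-specific hyperparameters: reasoning token/step/layer counts, adaptive
    halting and input gating controls, auxiliary ponder/consistency loss weights, and
    the layer index at which the reasoning module is injected (by default the middle
    of the transformer stack).

    Example:

    ```python
    >>> config = PhiReasoningConfig(num_reasoning_steps=4)
    >>> config.reasoning_injection_point  # defaults to num_hidden_layers // 2
    16
    ```
    """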
    model_type = "phi_reasoning"

    def __init__(
        self,
        # Standard Phi
        vocab_size: int = 51200,
        hidden_size: int = 2560,
        intermediate_size: int = 10240,
        num_hidden_layers: int = 32,
        num_attention_heads: int = 32,
        num_key_value_heads: int = 32,
        head_dim: Optional[int] = None,
        hidden_act: str = "gelu_new",
        max_position_embeddings: int = 2048,
        initializer_range: float = 0.02,
        layer_norm_eps: float = 1e-5,
        use_cache: bool = True,
        tie_word_embeddings: bool = False,
        rope_theta: float = 10000.0,
        rope_scaling: Optional[dict] = None,
        partial_rotary_factor: float = 0.5,
        qk_layernorm: bool = False,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        pad_token_id: int = 0,
        attention_dropout: float = 0.0,
        embd_pdrop: float = 0.0,
        resid_pdrop: float = 0.0,
        # Reasoning Config
        num_reasoning_tokens: int = 32,
        num_reasoning_steps: int = 8,
        num_reasoning_layers: int = 4,
        max_reasoning_steps: int = 16,
        min_reasoning_steps: int = 2,
        reasoning_dropout: float = 0.1,
        reasoning_intermediate_size: int = 2560,
        use_adaptive_halting: bool = True,
        halting_threshold: float = 0.8,
        use_input_gating: bool = True,
        gating_threshold: float = 0.3,
        ponder_loss_weight: float = 0.01,
        consistency_loss_weight: float = 0.001,
        share_reasoning_layers: bool = False,
        reasoning_injection_point: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim or hidden_size // num_attention_heads
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.partial_rotary_factor = partial_rotary_factor
        self.qk_layernorm = qk_layernorm
        self.attention_dropout = attention_dropout
        self.embd_pdrop = embd_pdrop
        self.resid_pdrop = resid_pdrop

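        # Reasoning module sizing (tokens, steps, layers) and regularization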
        self.num_reasoning_tokens = num_reasoning_tokens
        self.num_reasoning_steps = num_reasoning_steps
        self.num_reasoning_layers = num_reasoning_layers
        self.max_reasoning_steps = max_reasoning_steps
        self.min_reasoning_steps = min_reasoning_steps
        self.reasoning_dropout = reasoning_dropout
        self.reasoning_intermediate_size = reasoning_intermediate_size

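        # Adaptive computation controls, auxiliary loss weights, and injection point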
        self.use_adaptive_halting = use_adaptive_halting
        self.halting_threshold = halting_threshold
        self.use_input_gating = use_input_gating
        self.gating_threshold = gating_threshold
        self.ponder_loss_weight = ponder_loss_weight
        self.consistency_loss_weight = consistency_loss_weight
        self.share_reasoning_layers = share_reasoning_layers
        self.reasoning_injection_point = (
            reasoning_injection_point
            if reasoning_injection_point is not None
            else num_hidden_layers // 2
        )

        self.rope_parameters = {
            "rope_type": "default",
            "rope_theta": rope_theta,
            "partial_rotary_factor": partial_rotary_factor,
        }
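

# A minimal usage sketch, assuming only what this file defines: it overrides a few of
# the reasoning hyperparameters and round-trips the config through the standard
# PretrainedConfig serialization helpers. The "./phi_reasoning_config" path is just an
# illustrative output directory.
if __name__ == "__main__":
    config = PhiReasoningConfig(
        num_reasoning_tokens=16,
        max_reasoning_steps=8,
        halting_threshold=0.9,
    )

    # Derived field: falls back to the middle of the transformer stack (32 // 2 = 16).
    print(config.reasoning_injection_point)

    # Save and reload via the PretrainedConfig API; custom attributes are preserved.
    config.save_pretrained("./phi_reasoning_config")
    reloaded = PhiReasoningConfig.from_pretrained("./phi_reasoning_config")
    assert reloaded.num_reasoning_tokens == 16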