"""
SHIVIK-Code Configuration

Extends LlamaConfig with SHIVIK-Code-specific settings.
"""

from transformers import LlamaConfig


class ShivikCodeConfig(LlamaConfig):
    """
    Configuration class for SHIVIK-Code.
    
    Extends LlamaConfig with:
    - A 32,768-token default context window (8x YaRN RoPE scaling over
      a 4,096-token base)
    - Special token ids for tool calling
    - Special token ids for fill-in-the-middle (FIM) completion
    """
    
    model_type = "shivik_code"
    
    def __init__(
        self,
        vocab_size=128279,
        hidden_size=2048,
        intermediate_size=8192,
        num_hidden_layers=16,
        num_attention_heads=32,
        num_key_value_heads=8,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=128000,
        eos_token_id=128001,
        tie_word_embeddings=False,
        rope_theta=500000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        mlp_bias=False,
        # SHIVIK-Code specific
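        # Extra ids are appended past the base vocabulary; fim_middle_id
        # (128278) is the last entry in the 128,279-token vocab above.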
        tool_call_start_id=128256,
        tool_call_end_id=128257,
        tool_result_start_id=128258,
        tool_result_end_id=128259,
        fim_prefix_id=128276,
        fim_suffix_id=128277,
        fim_middle_id=128278,
        **kwargs,
    ):
        # Default to YaRN RoPE scaling: a factor of 8.0 over a 4,096-token
        # base matches the 32,768 max_position_embeddings default above.
        if rope_scaling is None:
            rope_scaling = {
                "type": "yarn",
                "factor": 8.0,
                "original_max_position_embeddings": 4096,
            }
        
        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            hidden_act=hidden_act,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            rms_norm_eps=rms_norm_eps,
            use_cache=use_cache,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            attention_bias=attention_bias,
            attention_dropout=attention_dropout,
            mlp_bias=mlp_bias,
            **kwargs,
        )
        
        # Tool tokens
        self.tool_call_start_id = tool_call_start_id
        self.tool_call_end_id = tool_call_end_id
        self.tool_result_start_id = tool_result_start_id
        self.tool_result_end_id = tool_result_end_id
        
        # FIM tokens
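        # FIM inference conventionally prompts
        # "<fim_prefix>pre<fim_suffix>suf<fim_middle>" and lets the model
        # generate the missing middle span (PSM ordering); the exact prompt
        # format SHIVIK-Code uses is an assumption here, not set by this file.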
        self.fim_prefix_id = fim_prefix_id
        self.fim_suffix_id = fim_suffix_id
        self.fim_middle_id = fim_middle_id
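

# A minimal usage sketch (illustrative only; registering with AutoConfig
# is an assumption about how the surrounding repo wires things up, not
# something this file requires):
if __name__ == "__main__":
    from transformers import AutoConfig

    # Let AutoConfig resolve the custom model_type "shivik_code".
    AutoConfig.register("shivik_code", ShivikCodeConfig)

    config = ShivikCodeConfig()
    # Defaults documented above: 32K context via 8x YaRN over a 4K base.
    assert config.max_position_embeddings == 32768
    assert config.rope_scaling["type"] == "yarn"
    assert config.rope_scaling["factor"] == 8.0
    # The FIM middle token is the last id in the 128,279-entry vocabulary.
    assert config.fim_middle_id == config.vocab_size - 1
    print("ShivikCodeConfig defaults are self-consistent.")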