Update configuration_steerling.py

#3
by AyaGL - opened
Files changed (1) hide show
  1. configuration_steerling.py +102 -0
configuration_steerling.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
class SteerlingConfig(PretrainedConfig):
    """Configuration object for the ``steerling`` model type.

    Every ``__init__`` keyword other than ``vocab_size`` is stored unchanged
    on the instance under the same name.  ``vocab_size``, the pad/bos/eos
    token ids and any remaining ``**kwargs`` are forwarded to
    ``PretrainedConfig.__init__``.

    When the caller does not supply them, the special-token ids default to
    ``pad_token_id=100277``, ``bos_token_id=100278`` and
    ``eos_token_id=100257``.

    NOTE(review): the meaning of the individual hyper-parameters is defined
    by the modeling code, which is not visible from this file -- confirm
    there before relying on any particular interpretation of a name.
    """

    model_type = "steerling"

    def __init__(
        self,
        vocab_size=100281,
        interpretable=True,
        n_layers=32,
        n_head=32,
        n_embd=4096,
        n_kv_heads=4,
        block_size=4096,
        diff_block_size=64,
        use_rms_norm=True,
        norm_eps=1e-05,
        norm_order="post",
        use_qk_norm=True,
        use_rope=True,
        rope_base=500000.0,
        rope_full_precision=True,
        clip_qkv=10.0,
        mlp_type="swiglu",
        activation="gelu",
        mlp_ratio=4,
        intermediate_size=None,
        use_bias=False,
        weight_sharing=True,
        mask_token_id=100280,
        endofchunk_token_id=100279,
        n_concepts=33732,
        n_unknown_concepts=101196,
        concept_dim=4096,
        use_attention_known=False,
        use_attention_unknown=False,
        topk_known=16,
        topk_known_features=32,
        unknown_topk=128,
        use_unknown=True,
        apply_topk_to_unknown=True,
        topk_on_logits=False,
        factorize_unknown=True,
        factorize_rank=256,
        use_epsilon_correction=True,
        concept_block_size=4096,
        pad_multiple=16,
        store_unknown_weights=False,
        inject_layer=16,
        inject_alpha=1.0,
        **kwargs,
    ):
        # All model-specific settings are recorded on the instance *before*
        # the base-class initialiser runs, in the same order as the signature.
        self.interpretable = interpretable

        self.n_layers = n_layers
        self.n_head = n_head
        self.n_embd = n_embd
        self.n_kv_heads = n_kv_heads
        self.block_size = block_size
        self.diff_block_size = diff_block_size

        self.use_rms_norm = use_rms_norm
        self.norm_eps = norm_eps
        self.norm_order = norm_order
        self.use_qk_norm = use_qk_norm

        self.use_rope = use_rope
        self.rope_base = rope_base
        self.rope_full_precision = rope_full_precision
        self.clip_qkv = clip_qkv

        self.mlp_type = mlp_type
        self.activation = activation
        self.mlp_ratio = mlp_ratio
        self.intermediate_size = intermediate_size
        self.use_bias = use_bias
        self.weight_sharing = weight_sharing

        self.mask_token_id = mask_token_id
        self.endofchunk_token_id = endofchunk_token_id

        self.n_concepts = n_concepts
        self.n_unknown_concepts = n_unknown_concepts
        self.concept_dim = concept_dim
        self.use_attention_known = use_attention_known
        self.use_attention_unknown = use_attention_unknown
        self.topk_known = topk_known
        self.topk_known_features = topk_known_features
        self.unknown_topk = unknown_topk
        self.use_unknown = use_unknown
        self.apply_topk_to_unknown = apply_topk_to_unknown
        self.topk_on_logits = topk_on_logits
        self.factorize_unknown = factorize_unknown
        self.factorize_rank = factorize_rank
        self.use_epsilon_correction = use_epsilon_correction
        self.concept_block_size = concept_block_size
        self.pad_multiple = pad_multiple
        self.store_unknown_weights = store_unknown_weights

        self.inject_layer = inject_layer
        self.inject_alpha = inject_alpha

        # Fill in the special-token ids only when the caller did not pass
        # them; an explicitly supplied value (including None) is kept.
        kwargs.setdefault("pad_token_id", 100277)
        kwargs.setdefault("bos_token_id", 100278)
        kwargs.setdefault("eos_token_id", 100257)

        super().__init__(vocab_size=vocab_size, **kwargs)