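Update configuration_trillion.py: add `sliding_window` parameter (default 4096) and derive per-layer `layer_types` from `global_attention_freq`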
Browse files
configuration_trillion.py
CHANGED
|
@@ -174,6 +174,7 @@ class TrillionConfig(PretrainedConfig):
|
|
| 174 |
rope_scaling=None,
|
| 175 |
global_attention_freq=4,
|
| 176 |
attn_temperature_tuning=True,
|
|
|
|
| 177 |
floor_scale=1.0,
|
| 178 |
attn_scale=1.0,
|
| 179 |
attention_bias=False,
|
|
@@ -183,6 +184,7 @@ class TrillionConfig(PretrainedConfig):
|
|
| 183 |
**kwargs,
|
| 184 |
):
|
| 185 |
self.attn_temperature_tuning = attn_temperature_tuning
|
|
|
|
| 186 |
self.attn_scale = attn_scale
|
| 187 |
self.floor_scale = floor_scale
|
| 188 |
self.global_attention_freq = global_attention_freq
|
|
@@ -215,6 +217,10 @@ class TrillionConfig(PretrainedConfig):
|
|
| 215 |
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
| 216 |
rope_config_validation(self)
|
| 217 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
super().__init__(
|
| 219 |
pad_token_id=pad_token_id,
|
| 220 |
bos_token_id=bos_token_id,
|
|
|
|
| 174 |
rope_scaling=None,
|
| 175 |
global_attention_freq=4,
|
| 176 |
attn_temperature_tuning=True,
|
| 177 |
+
sliding_window=4096,
|
| 178 |
floor_scale=1.0,
|
| 179 |
attn_scale=1.0,
|
| 180 |
attention_bias=False,
|
|
|
|
| 184 |
**kwargs,
|
| 185 |
):
|
| 186 |
self.attn_temperature_tuning = attn_temperature_tuning
|
| 187 |
+
self.sliding_window = sliding_window
|
| 188 |
self.attn_scale = attn_scale
|
| 189 |
self.floor_scale = floor_scale
|
| 190 |
self.global_attention_freq = global_attention_freq
|
|
|
|
| 217 |
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
| 218 |
rope_config_validation(self)
|
| 219 |
|
| 220 |
+
self.layer_types = [
|
| 221 |
+
"sliding_attention" if bool((i + 1) % self.global_attention_freq) else "full_attention" for i in range(self.num_hidden_layers)
|
| 222 |
+
]
|
| 223 |
+
|
| 224 |
super().__init__(
|
| 225 |
pad_token_id=pad_token_id,
|
| 226 |
bos_token_id=bos_token_id,
|