syx
commited on
Commit
·
3909b06
1
Parent(s):
f94ad98
minor
Browse files
README.md
CHANGED
|
@@ -4,8 +4,8 @@ language:
|
|
| 4 |
- en
|
| 5 |
---
|
| 6 |
|
| 7 |
-
# Model Card for
|
| 8 |
-
The
|
| 9 |
|
| 10 |
<img src="takeaway.png" alt="avatar" width="300" height="200"/>
|
| 11 |
|
|
@@ -13,7 +13,7 @@ The average performance is evaluated using benchmarks from the OpenLLM Leaderboa
|
|
| 13 |
|
| 14 |
## Inference
|
| 15 |
|
| 16 |
-
Our code for accelerating
|
| 17 |
|
| 18 |
## Chat-Template
|
| 19 |
|
|
@@ -25,7 +25,7 @@ We take ChatML as our chat template:
|
|
| 25 |
|
| 26 |
## Allow Finetuning
|
| 27 |
|
| 28 |
-
As we merged the predictors for FFN neurons in models, you can finetune
|
| 29 |
|
| 30 |
## License
|
| 31 |
|
|
|
|
| 4 |
- en
|
| 5 |
---
|
| 6 |
|
| 7 |
+
# Model Card for TurboSparse-Mixtral
|
| 8 |
+
The TurboSparse-Mixtral Large Language Model (LLM) is an sparsified version of the Mixtral.
|
| 9 |
|
| 10 |
<img src="takeaway.png" alt="avatar" width="300" height="200"/>
|
| 11 |
|
|
|
|
| 13 |
|
| 14 |
## Inference
|
| 15 |
|
| 16 |
+
Our code for accelerating TurboSparse-Mixtral is currently being refined. Stay tuned! Now you can run this model like dense model.
|
| 17 |
|
| 18 |
## Chat-Template
|
| 19 |
|
|
|
|
| 25 |
|
| 26 |
## Allow Finetuning
|
| 27 |
|
| 28 |
+
As we merged the predictors for FFN neurons in models, you can finetune TurboSparse-Mixtral with any framework and algorithm.
|
| 29 |
|
| 30 |
## License
|
| 31 |
|
config.json
CHANGED
|
@@ -3,9 +3,9 @@
|
|
| 3 |
"TurboSparseMixtralForCausalLM"
|
| 4 |
],
|
| 5 |
"auto_map": {
|
| 6 |
-
"AutoConfig": "
|
| 7 |
-
"AutoModel": "
|
| 8 |
-
"AutoModelForCausalLM": "
|
| 9 |
},
|
| 10 |
"attention_dropout": 0.0,
|
| 11 |
"bos_token_id": 1,
|
|
@@ -15,7 +15,7 @@
|
|
| 15 |
"initializer_range": 0.02,
|
| 16 |
"intermediate_size": 14336,
|
| 17 |
"max_position_embeddings": 32768,
|
| 18 |
-
"model_type": "
|
| 19 |
"num_attention_heads": 32,
|
| 20 |
"num_experts_per_tok": 2,
|
| 21 |
"num_hidden_layers": 32,
|
|
|
|
| 3 |
"TurboSparseMixtralForCausalLM"
|
| 4 |
],
|
| 5 |
"auto_map": {
|
| 6 |
+
"AutoConfig": "configuration_turbosparsemixtral.TurboSparseMixtralConfig",
|
| 7 |
+
"AutoModel": "modeling_turbosparsemixtral.TurboSparseMixtralForCausalLM",
|
| 8 |
+
"AutoModelForCausalLM": "modeling_turbosparsemixtral.TurboSparseMixtralForCausalLM"
|
| 9 |
},
|
| 10 |
"attention_dropout": 0.0,
|
| 11 |
"bos_token_id": 1,
|
|
|
|
| 15 |
"initializer_range": 0.02,
|
| 16 |
"intermediate_size": 14336,
|
| 17 |
"max_position_embeddings": 32768,
|
| 18 |
+
"model_type": "turbosparsemixtral",
|
| 19 |
"num_attention_heads": 32,
|
| 20 |
"num_experts_per_tok": 2,
|
| 21 |
"num_hidden_layers": 32,
|
configuration_supersparsemixtral.py → configuration_turbosparsemixtral.py
RENAMED
|
@@ -22,7 +22,7 @@ from transformers.utils import logging
|
|
| 22 |
|
| 23 |
logger = logging.get_logger(__name__)
|
| 24 |
|
| 25 |
-
class
|
| 26 |
r"""
|
| 27 |
This is the configuration class to store the configuration of a [`MixtralModel`]. It is used to instantiate an
|
| 28 |
Mixtral model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
|
@@ -106,7 +106,7 @@ class SuperSparseMixtralConfig(PretrainedConfig):
|
|
| 106 |
>>> configuration = model.config
|
| 107 |
```"""
|
| 108 |
|
| 109 |
-
model_type = "
|
| 110 |
keys_to_ignore_at_inference = ["past_key_values"]
|
| 111 |
|
| 112 |
def __init__(
|
|
|
|
| 22 |
|
| 23 |
logger = logging.get_logger(__name__)
|
| 24 |
|
| 25 |
+
class TurboSparseMixtralConfig(PretrainedConfig):
|
| 26 |
r"""
|
| 27 |
This is the configuration class to store the configuration of a [`MixtralModel`]. It is used to instantiate an
|
| 28 |
Mixtral model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
|
|
|
| 106 |
>>> configuration = model.config
|
| 107 |
```"""
|
| 108 |
|
| 109 |
+
model_type = "turbosparsemixtral"
|
| 110 |
keys_to_ignore_at_inference = ["past_key_values"]
|
| 111 |
|
| 112 |
def __init__(
|
modeling_supersparsemixtral.py → modeling_turbosparsemixtral.py
RENAMED
|
@@ -54,7 +54,7 @@ from transformers.utils import (
|
|
| 54 |
replace_return_docstrings,
|
| 55 |
is_torch_fx_available,
|
| 56 |
)
|
| 57 |
-
from .
|
| 58 |
@dataclass
|
| 59 |
class AttentionMaskConverter:
|
| 60 |
"""
|
|
@@ -634,7 +634,7 @@ def _get_unpad_data(attention_mask):
|
|
| 634 |
|
| 635 |
|
| 636 |
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mixtral
|
| 637 |
-
class
|
| 638 |
def __init__(self, hidden_size, eps=1e-6):
|
| 639 |
"""
|
| 640 |
MixtralRMSNorm is equivalent to T5LayerNorm
|
|
@@ -653,7 +653,7 @@ class SuperSparseMixtralRMSNorm(nn.Module):
|
|
| 653 |
|
| 654 |
# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mixtral
|
| 655 |
# TODO @longjie no longer copied from Mistral after static cache
|
| 656 |
-
class
|
| 657 |
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
| 658 |
super().__init__()
|
| 659 |
|
|
@@ -742,13 +742,13 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
|
| 742 |
|
| 743 |
# copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral
|
| 744 |
# TODO @longjie no longer copied from Mistral after static cache
|
| 745 |
-
class
|
| 746 |
"""
|
| 747 |
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
|
| 748 |
and "Generating Long Sequences with Sparse Transformers".
|
| 749 |
"""
|
| 750 |
|
| 751 |
-
def __init__(self, config:
|
| 752 |
super().__init__()
|
| 753 |
self.config = config
|
| 754 |
self.layer_idx = layer_idx
|
|
@@ -779,7 +779,7 @@ class SuperSparseMixtralAttention(nn.Module):
|
|
| 779 |
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
|
| 780 |
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
| 781 |
|
| 782 |
-
self.rotary_emb =
|
| 783 |
self.head_dim,
|
| 784 |
max_position_embeddings=self.max_position_embeddings,
|
| 785 |
base=self.rope_theta,
|
|
@@ -867,7 +867,7 @@ class SuperSparseMixtralAttention(nn.Module):
|
|
| 867 |
|
| 868 |
# copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral
|
| 869 |
# TODO @longjie no longer copied from Mistral after static cache
|
| 870 |
-
class
|
| 871 |
"""
|
| 872 |
Mixtral flash attention module. This module inherits from `MixtralAttention` as the weights of the module stays
|
| 873 |
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
|
@@ -1154,7 +1154,7 @@ class SuperSparseMixtralFlashAttention2(SuperSparseMixtralAttention):
|
|
| 1154 |
|
| 1155 |
# copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral
|
| 1156 |
# TODO @longjie no longer copied from Mistral after static cache
|
| 1157 |
-
class
|
| 1158 |
"""
|
| 1159 |
Mixtral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
| 1160 |
`MixtralAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
|
|
@@ -1246,9 +1246,9 @@ class SuperSparseMixtralSdpaAttention(SuperSparseMixtralAttention):
|
|
| 1246 |
|
| 1247 |
|
| 1248 |
MIXTRAL_ATTENTION_CLASSES = {
|
| 1249 |
-
"eager":
|
| 1250 |
-
"flash_attention_2":
|
| 1251 |
-
"sdpa":
|
| 1252 |
}
|
| 1253 |
|
| 1254 |
class MLP(nn.Module):
|
|
@@ -1264,8 +1264,8 @@ class MLP(nn.Module):
|
|
| 1264 |
x = self.fc2(x)
|
| 1265 |
x = x.sigmoid()
|
| 1266 |
return x
|
| 1267 |
-
class
|
| 1268 |
-
def __init__(self, config:
|
| 1269 |
super().__init__()
|
| 1270 |
self.ffn_dim = config.intermediate_size
|
| 1271 |
self.hidden_dim = config.hidden_size
|
|
@@ -1288,7 +1288,7 @@ class SuperSparseMixtralBlockSparseTop2MLP(nn.Module):
|
|
| 1288 |
return current_hidden_states
|
| 1289 |
|
| 1290 |
|
| 1291 |
-
class
|
| 1292 |
"""
|
| 1293 |
This implementation is
|
| 1294 |
strictly equivalent to standard MoE with full capacity (no
|
|
@@ -1310,7 +1310,7 @@ class SuperSparseMixtralSparseMoeBlock(nn.Module):
|
|
| 1310 |
# gating
|
| 1311 |
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
|
| 1312 |
|
| 1313 |
-
self.experts = nn.ModuleList([
|
| 1314 |
|
| 1315 |
# Jitter parameters
|
| 1316 |
self.jitter_noise = config.router_jitter_noise
|
|
@@ -1356,16 +1356,16 @@ class SuperSparseMixtralSparseMoeBlock(nn.Module):
|
|
| 1356 |
return final_hidden_states, router_logits
|
| 1357 |
|
| 1358 |
|
| 1359 |
-
class
|
| 1360 |
-
def __init__(self, config:
|
| 1361 |
super().__init__()
|
| 1362 |
self.hidden_size = config.hidden_size
|
| 1363 |
|
| 1364 |
self.self_attn = MIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
|
| 1365 |
|
| 1366 |
-
self.block_sparse_moe =
|
| 1367 |
-
self.input_layernorm =
|
| 1368 |
-
self.post_attention_layernorm =
|
| 1369 |
|
| 1370 |
def forward(
|
| 1371 |
self,
|
|
@@ -1451,11 +1451,11 @@ MIXTRAL_START_DOCSTRING = r"""
|
|
| 1451 |
MIXTRAL_START_DOCSTRING,
|
| 1452 |
)
|
| 1453 |
# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2PreTrainedModel with Qwen2->Mixtral
|
| 1454 |
-
class
|
| 1455 |
-
config_class =
|
| 1456 |
base_model_prefix = "model"
|
| 1457 |
supports_gradient_checkpointing = True
|
| 1458 |
-
_no_split_modules = ["
|
| 1459 |
_skip_keys_device_placement = "past_key_values"
|
| 1460 |
_supports_flash_attn_2 = True
|
| 1461 |
_supports_sdpa = True
|
|
@@ -1546,7 +1546,7 @@ MIXTRAL_INPUTS_DOCSTRING = r"""
|
|
| 1546 |
)
|
| 1547 |
# copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral
|
| 1548 |
# TODO @longjie no longer copied from Mistral after static cache
|
| 1549 |
-
class
|
| 1550 |
"""
|
| 1551 |
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MixtralDecoderLayer`]
|
| 1552 |
|
|
@@ -1554,17 +1554,17 @@ class SuperSparseMixtralModel(SuperSparseMixtralPreTrainedModel):
|
|
| 1554 |
config: MixtralConfig
|
| 1555 |
"""
|
| 1556 |
|
| 1557 |
-
def __init__(self, config:
|
| 1558 |
super().__init__(config)
|
| 1559 |
self.padding_idx = config.pad_token_id
|
| 1560 |
self.vocab_size = config.vocab_size
|
| 1561 |
|
| 1562 |
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 1563 |
self.layers = nn.ModuleList(
|
| 1564 |
-
[
|
| 1565 |
)
|
| 1566 |
self._attn_implementation = config._attn_implementation
|
| 1567 |
-
self.norm =
|
| 1568 |
|
| 1569 |
self.gradient_checkpointing = False
|
| 1570 |
# Initialize weights and apply final processing
|
|
@@ -1741,12 +1741,12 @@ class SuperSparseMixtralModel(SuperSparseMixtralPreTrainedModel):
|
|
| 1741 |
)
|
| 1742 |
|
| 1743 |
|
| 1744 |
-
class
|
| 1745 |
_tied_weights_keys = ["lm_head.weight"]
|
| 1746 |
|
| 1747 |
def __init__(self, config):
|
| 1748 |
super().__init__(config)
|
| 1749 |
-
self.model =
|
| 1750 |
self.vocab_size = config.vocab_size
|
| 1751 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 1752 |
self.router_aux_loss_coef = config.router_aux_loss_coef
|
|
@@ -1974,11 +1974,11 @@ class SuperSparseMixtralForCausalLM(SuperSparseMixtralPreTrainedModel):
|
|
| 1974 |
MIXTRAL_START_DOCSTRING,
|
| 1975 |
)
|
| 1976 |
# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mixtral, LLAMA->MIXTRAL
|
| 1977 |
-
class
|
| 1978 |
def __init__(self, config):
|
| 1979 |
super().__init__(config)
|
| 1980 |
self.num_labels = config.num_labels
|
| 1981 |
-
self.model =
|
| 1982 |
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
| 1983 |
|
| 1984 |
# Initialize weights and apply final processing
|
|
@@ -2090,11 +2090,11 @@ class SuperSparseMixtralForSequenceClassification(SuperSparseMixtralPreTrainedMo
|
|
| 2090 |
MIXTRAL_START_DOCSTRING,
|
| 2091 |
)
|
| 2092 |
# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Mixtral, LLAMA->MIXTRAL
|
| 2093 |
-
class
|
| 2094 |
def __init__(self, config):
|
| 2095 |
super().__init__(config)
|
| 2096 |
self.num_labels = config.num_labels
|
| 2097 |
-
self.model =
|
| 2098 |
if getattr(config, "classifier_dropout", None) is not None:
|
| 2099 |
classifier_dropout = config.classifier_dropout
|
| 2100 |
elif getattr(config, "hidden_dropout", None) is not None:
|
|
|
|
| 54 |
replace_return_docstrings,
|
| 55 |
is_torch_fx_available,
|
| 56 |
)
|
| 57 |
+
from .configuration_turbosparsemixtral import TurboSparseMixtralConfig
|
| 58 |
@dataclass
|
| 59 |
class AttentionMaskConverter:
|
| 60 |
"""
|
|
|
|
| 634 |
|
| 635 |
|
| 636 |
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mixtral
|
| 637 |
+
class TurboSparseMixtralRMSNorm(nn.Module):
|
| 638 |
def __init__(self, hidden_size, eps=1e-6):
|
| 639 |
"""
|
| 640 |
MixtralRMSNorm is equivalent to T5LayerNorm
|
|
|
|
| 653 |
|
| 654 |
# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mixtral
|
| 655 |
# TODO @longjie no longer copied from Mistral after static cache
|
| 656 |
+
class TurboSparseMixtralRotaryEmbedding(nn.Module):
|
| 657 |
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
| 658 |
super().__init__()
|
| 659 |
|
|
|
|
| 742 |
|
| 743 |
# copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral
|
| 744 |
# TODO @longjie no longer copied from Mistral after static cache
|
| 745 |
+
class TurboSparseMixtralAttention(nn.Module):
|
| 746 |
"""
|
| 747 |
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
|
| 748 |
and "Generating Long Sequences with Sparse Transformers".
|
| 749 |
"""
|
| 750 |
|
| 751 |
+
def __init__(self, config: TurboSparseMixtralConfig, layer_idx: Optional[int] = None):
|
| 752 |
super().__init__()
|
| 753 |
self.config = config
|
| 754 |
self.layer_idx = layer_idx
|
|
|
|
| 779 |
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
|
| 780 |
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
| 781 |
|
| 782 |
+
self.rotary_emb = TurboSparseMixtralRotaryEmbedding(
|
| 783 |
self.head_dim,
|
| 784 |
max_position_embeddings=self.max_position_embeddings,
|
| 785 |
base=self.rope_theta,
|
|
|
|
| 867 |
|
| 868 |
# copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral
|
| 869 |
# TODO @longjie no longer copied from Mistral after static cache
|
| 870 |
+
class TurboSparseMixtralFlashAttention2(TurboSparseMixtralAttention):
|
| 871 |
"""
|
| 872 |
Mixtral flash attention module. This module inherits from `MixtralAttention` as the weights of the module stays
|
| 873 |
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
|
|
|
| 1154 |
|
| 1155 |
# copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral
|
| 1156 |
# TODO @longjie no longer copied from Mistral after static cache
|
| 1157 |
+
class TurboSparseMixtralSdpaAttention(TurboSparseMixtralAttention):
|
| 1158 |
"""
|
| 1159 |
Mixtral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
| 1160 |
`MixtralAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
|
|
|
|
| 1246 |
|
| 1247 |
|
| 1248 |
MIXTRAL_ATTENTION_CLASSES = {
|
| 1249 |
+
"eager": TurboSparseMixtralAttention,
|
| 1250 |
+
"flash_attention_2": TurboSparseMixtralFlashAttention2,
|
| 1251 |
+
"sdpa": TurboSparseMixtralSdpaAttention,
|
| 1252 |
}
|
| 1253 |
|
| 1254 |
class MLP(nn.Module):
|
|
|
|
| 1264 |
x = self.fc2(x)
|
| 1265 |
x = x.sigmoid()
|
| 1266 |
return x
|
| 1267 |
+
class TurboSparseMixtralBlockSparseTop2MLP(nn.Module):
|
| 1268 |
+
def __init__(self, config: TurboSparseMixtralConfig, layer_id):
|
| 1269 |
super().__init__()
|
| 1270 |
self.ffn_dim = config.intermediate_size
|
| 1271 |
self.hidden_dim = config.hidden_size
|
|
|
|
| 1288 |
return current_hidden_states
|
| 1289 |
|
| 1290 |
|
| 1291 |
+
class TurboSparseMixtralSparseMoeBlock(nn.Module):
|
| 1292 |
"""
|
| 1293 |
This implementation is
|
| 1294 |
strictly equivalent to standard MoE with full capacity (no
|
|
|
|
| 1310 |
# gating
|
| 1311 |
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
|
| 1312 |
|
| 1313 |
+
self.experts = nn.ModuleList([TurboSparseMixtralBlockSparseTop2MLP(config, layer_id) for _ in range(self.num_experts)])
|
| 1314 |
|
| 1315 |
# Jitter parameters
|
| 1316 |
self.jitter_noise = config.router_jitter_noise
|
|
|
|
| 1356 |
return final_hidden_states, router_logits
|
| 1357 |
|
| 1358 |
|
| 1359 |
+
class TurboSparseMixtralDecoderLayer(nn.Module):
|
| 1360 |
+
def __init__(self, config: TurboSparseMixtralConfig, layer_idx: int):
|
| 1361 |
super().__init__()
|
| 1362 |
self.hidden_size = config.hidden_size
|
| 1363 |
|
| 1364 |
self.self_attn = MIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
|
| 1365 |
|
| 1366 |
+
self.block_sparse_moe = TurboSparseMixtralSparseMoeBlock(config, layer_idx)
|
| 1367 |
+
self.input_layernorm = TurboSparseMixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 1368 |
+
self.post_attention_layernorm = TurboSparseMixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 1369 |
|
| 1370 |
def forward(
|
| 1371 |
self,
|
|
|
|
| 1451 |
MIXTRAL_START_DOCSTRING,
|
| 1452 |
)
|
| 1453 |
# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2PreTrainedModel with Qwen2->Mixtral
|
| 1454 |
+
class TurboSparseMixtralPreTrainedModel(PreTrainedModel):
|
| 1455 |
+
config_class = TurboSparseMixtralConfig
|
| 1456 |
base_model_prefix = "model"
|
| 1457 |
supports_gradient_checkpointing = True
|
| 1458 |
+
_no_split_modules = ["TurboSparseMixtralDecoderLayer"]
|
| 1459 |
_skip_keys_device_placement = "past_key_values"
|
| 1460 |
_supports_flash_attn_2 = True
|
| 1461 |
_supports_sdpa = True
|
|
|
|
| 1546 |
)
|
| 1547 |
# copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral
|
| 1548 |
# TODO @longjie no longer copied from Mistral after static cache
|
| 1549 |
+
class TurboSparseMixtralModel(TurboSparseMixtralPreTrainedModel):
|
| 1550 |
"""
|
| 1551 |
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MixtralDecoderLayer`]
|
| 1552 |
|
|
|
|
| 1554 |
config: MixtralConfig
|
| 1555 |
"""
|
| 1556 |
|
| 1557 |
+
def __init__(self, config: TurboSparseMixtralConfig):
|
| 1558 |
super().__init__(config)
|
| 1559 |
self.padding_idx = config.pad_token_id
|
| 1560 |
self.vocab_size = config.vocab_size
|
| 1561 |
|
| 1562 |
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 1563 |
self.layers = nn.ModuleList(
|
| 1564 |
+
[TurboSparseMixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 1565 |
)
|
| 1566 |
self._attn_implementation = config._attn_implementation
|
| 1567 |
+
self.norm = TurboSparseMixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 1568 |
|
| 1569 |
self.gradient_checkpointing = False
|
| 1570 |
# Initialize weights and apply final processing
|
|
|
|
| 1741 |
)
|
| 1742 |
|
| 1743 |
|
| 1744 |
+
class TurboSparseMixtralForCausalLM(TurboSparseMixtralPreTrainedModel):
|
| 1745 |
_tied_weights_keys = ["lm_head.weight"]
|
| 1746 |
|
| 1747 |
def __init__(self, config):
|
| 1748 |
super().__init__(config)
|
| 1749 |
+
self.model = TurboSparseMixtralModel(config)
|
| 1750 |
self.vocab_size = config.vocab_size
|
| 1751 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 1752 |
self.router_aux_loss_coef = config.router_aux_loss_coef
|
|
|
|
| 1974 |
MIXTRAL_START_DOCSTRING,
|
| 1975 |
)
|
| 1976 |
# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mixtral, LLAMA->MIXTRAL
|
| 1977 |
+
class TurboSparseMixtralForSequenceClassification(TurboSparseMixtralPreTrainedModel):
|
| 1978 |
def __init__(self, config):
|
| 1979 |
super().__init__(config)
|
| 1980 |
self.num_labels = config.num_labels
|
| 1981 |
+
self.model = TurboSparseMixtralModel(config)
|
| 1982 |
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
| 1983 |
|
| 1984 |
# Initialize weights and apply final processing
|
|
|
|
| 2090 |
MIXTRAL_START_DOCSTRING,
|
| 2091 |
)
|
| 2092 |
# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Mixtral, LLAMA->MIXTRAL
|
| 2093 |
+
class TurboSparseMixtralForTokenClassification(TurboSparseMixtralPreTrainedModel):
|
| 2094 |
def __init__(self, config):
|
| 2095 |
super().__init__(config)
|
| 2096 |
self.num_labels = config.num_labels
|
| 2097 |
+
self.model = TurboSparseMixtralModel(config)
|
| 2098 |
if getattr(config, "classifier_dropout", None) is not None:
|
| 2099 |
classifier_dropout = config.classifier_dropout
|
| 2100 |
elif getattr(config, "hidden_dropout", None) is not None:
|