ToastyPigeon committed
Commit eca3b01 · verified · 1 Parent(s): 960b422

Upload folder using huggingface_hub

Vostral Mini 3B (Text Only)_layers_output.png ADDED
chat_template.jinja ADDED
@@ -0,0 +1 @@
+ {% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% set messages = messages[1:] %}{% else %}{% set system_message = '' %}{% endif %}{{ bos_token }}{% if system_message != '' %}{{ '[SYSTEM_PROMPT]' + system_message + '[/SYSTEM_PROMPT]' }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST]' + message['content'] + '[/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}
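For orientation, a minimal sketch (not part of this commit) of how the template above would render a short conversation. The repo id is a placeholder assumption, and the exact special-token strings come from this repo's tokenizer files rather than from the template itself.

```python
# Minimal sketch: render chat_template.jinja via apply_chat_template.
# The repo id below is a hypothetical placeholder.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("ToastyPigeon/some-loopstral-repo")  # hypothetical id

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi there."},
    {"role": "user", "content": "How are you?"},
]

# With the template above, the rendered string should come out roughly as:
# <s>[SYSTEM_PROMPT]...[/SYSTEM_PROMPT][INST]Hello![/INST]Hi there.</s>[INST]How are you?[/INST]
text = tokenizer.apply_chat_template(messages, tokenize=False)
print(text)
```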
config.json ADDED
@@ -0,0 +1,31 @@
+ {
+   "architectures": [
+     "MistralForCausalLM"
+   ],
+   "attention_dropout": 0.0,
+   "bos_token_id": 1,
+   "eos_token_id": 2,
+   "head_dim": 128,
+   "auto_map": {
+     "AutoConfig": "configuration_loopstral.LoopstralConfig",
+     "AutoModelForCausalLM": "modeling_loopstral.LoopstralForCausalLM"
+   },
+   "hidden_act": "silu",
+   "hidden_size": 3072,
+   "initializer_range": 0.02,
+   "intermediate_size": 8192,
+   "max_position_embeddings": 32768,
+   "model_type": "mistral",
+   "num_attention_heads": 32,
+   "num_hidden_layers": 30,
+   "num_key_value_heads": 8,
+   "layer_sequence": [[0,30,1]],
+   "rms_norm_eps": 1e-05,
+   "rope_theta": 100000000.0,
+   "sliding_window": null,
+   "tie_word_embeddings": false,
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.54.0",
+   "use_cache": true,
+   "vocab_size": 131072
+ }
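Because `auto_map` routes `AutoConfig`/`AutoModelForCausalLM` to the custom Loopstral classes shipped in this repo, loading through the Auto classes needs `trust_remote_code=True`. A hedged sketch; the repo id is a placeholder assumption:

```python
# Sketch only: load the custom Loopstral model via the Auto classes.
import torch
from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "ToastyPigeon/some-loopstral-repo"  # hypothetical id
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
print(config.layer_sequence)  # [[0, 30, 1]] -> run layers 0..29 once, in order

model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
```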
configuration_loopstral.py ADDED
@@ -0,0 +1,179 @@
+ # coding=utf-8
+ # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Mistral model configuration"""
+
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class LoopstralConfig(PretrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of a [`MistralModel`]. It is used to instantiate a
+     Mistral model according to the specified arguments, defining the model architecture. Instantiating a configuration
+     with the defaults will yield a similar configuration to that of the Mistral-7B-v0.1 or Mistral-7B-Instruct-v0.1.
+
+     [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)
+     [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+
+     Args:
+         vocab_size (`int`, *optional*, defaults to 32000):
+             Vocabulary size of the Mistral model. Defines the number of different tokens that can be represented by the
+             `inputs_ids` passed when calling [`MistralModel`]
+         hidden_size (`int`, *optional*, defaults to 4096):
+             Dimension of the hidden representations.
+         intermediate_size (`int`, *optional*, defaults to 14336):
+             Dimension of the MLP representations.
+         num_hidden_layers (`int`, *optional*, defaults to 32):
+             Number of hidden layers in the Transformer encoder.
+         num_attention_heads (`int`, *optional*, defaults to 32):
+             Number of attention heads for each attention layer in the Transformer encoder.
+         num_key_value_heads (`int`, *optional*, defaults to 8):
+             This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+             `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+             `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+             converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+             by meanpooling all the original heads within that group. For more details, check out [this
+             paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `8`.
+         layer_sequence (`list` of 3-element lists, *optional*, defaults to `[[0, 32, 1]]`):
+             Order in which to run the layers. Defaults to running all layers in order once.
+             List of lists; each sublist is three numbers: `[start_layer, end_layer, num_loops]`.
+         head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):
+             The attention head dimension.
+         hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+             The non-linear activation function (function or string) in the decoder.
+         max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
+             The maximum sequence length that this model might ever be used with. Mistral's sliding window attention
+             allows sequences of up to 4096*32 tokens.
+         initializer_range (`float`, *optional*, defaults to 0.02):
+             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+         rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+             The epsilon used by the rms normalization layers.
+         use_cache (`bool`, *optional*, defaults to `True`):
+             Whether or not the model should return the last key/values attentions (not used by all models). Only
+             relevant if `config.is_decoder=True`.
+         pad_token_id (`int`, *optional*):
+             The id of the padding token.
+         bos_token_id (`int`, *optional*, defaults to 1):
+             The id of the "beginning-of-sequence" token.
+         eos_token_id (`int`, *optional*, defaults to 2):
+             The id of the "end-of-sequence" token.
+         tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+             Whether the model's input and output word embeddings should be tied.
+         rope_theta (`float`, *optional*, defaults to 10000.0):
+             The base period of the RoPE embeddings.
+         sliding_window (`int`, *optional*, defaults to 4096):
+             Sliding window attention window size. If not specified, will default to `4096`.
+         attention_dropout (`float`, *optional*, defaults to 0.0):
+             The dropout ratio for the attention probabilities.
+
+     ```python
+     >>> from transformers import MistralModel, MistralConfig
+
+     >>> # Initializing a Mistral 7B style configuration
+     >>> configuration = MistralConfig()
+
+     >>> # Initializing a model from the Mistral 7B style configuration
+     >>> model = MistralModel(configuration)
+
+     >>> # Accessing the model configuration
+     >>> configuration = model.config
+     ```"""
+
+     model_type = "mistral"
+     keys_to_ignore_at_inference = ["past_key_values"]
+     # Default tensor parallel plan for base model `MistralModel`
+     base_model_tp_plan = {
+         "layers.*.self_attn.q_proj": "colwise",
+         "layers.*.self_attn.k_proj": "colwise",
+         "layers.*.self_attn.v_proj": "colwise",
+         "layers.*.self_attn.o_proj": "rowwise",
+         "layers.*.mlp.gate_proj": "colwise",
+         "layers.*.mlp.up_proj": "colwise",
+         "layers.*.mlp.down_proj": "rowwise",
+     }
+     base_model_pp_plan = {
+         "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+         "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+         "norm": (["hidden_states"], ["hidden_states"]),
+     }
+
+     def __init__(
+         self,
+         vocab_size=32000,
+         hidden_size=4096,
+         intermediate_size=14336,
+         num_hidden_layers=32,
+         num_attention_heads=32,
+         num_key_value_heads=8,
+         layer_sequence=[[0,32,1]],
+         head_dim=None,
+         hidden_act="silu",
+         max_position_embeddings=4096 * 32,
+         initializer_range=0.02,
+         rms_norm_eps=1e-6,
+         use_cache=True,
+         pad_token_id=None,
+         bos_token_id=1,
+         eos_token_id=2,
+         tie_word_embeddings=False,
+         rope_theta=10000.0,
+         sliding_window=4096,
+         attention_dropout=0.0,
+         **kwargs,
+     ):
+         self.vocab_size = vocab_size
+         self.max_position_embeddings = max_position_embeddings
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.sliding_window = sliding_window
+         self.head_dim = head_dim
+
+         # for backward compatibility
+         if num_key_value_heads is None:
+             num_key_value_heads = num_attention_heads
+
+         self.num_key_value_heads = num_key_value_heads
+         self.layer_sequence = layer_sequence
+         self.hidden_act = hidden_act
+         self.initializer_range = initializer_range
+         self.rms_norm_eps = rms_norm_eps
+         self.use_cache = use_cache
+         self.rope_theta = rope_theta
+         self.attention_dropout = attention_dropout
+
+         if "layer_types" in kwargs:
+             logger.warning_once(
+                 "Detected Mistral model with layer_types. Consider using AutoModel or Ministral classes instead to enable alternating attention compatibility."
+             )
+
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
+
+
+ __all__ = ["LoopstralConfig"]
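A small sketch of the `layer_sequence` semantics documented above: each `[start_layer, end_layer, num_loops]` entry runs layers `start_layer` up to (but not including) `end_layer`, `num_loops` times. The values below are illustrative, not the checkpoint's shipped settings; only the dimensions are taken from config.json.

```python
# Illustrative only: a LoopstralConfig that runs layers 0-9 once, loops layers 10-19
# twice, then runs layers 20-29 once. Dimensions mirror this repo's config.json.
from configuration_loopstral import LoopstralConfig

config = LoopstralConfig(
    vocab_size=131072,
    hidden_size=3072,
    intermediate_size=8192,
    num_hidden_layers=30,
    num_attention_heads=32,
    num_key_value_heads=8,
    head_dim=128,
    layer_sequence=[[0, 10, 1], [10, 20, 2], [20, 30, 1]],
    rope_theta=100000000.0,
    sliding_window=None,
)
print(config.layer_sequence)
```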
generation_config.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 1,
+   "eos_token_id": 2,
+   "transformers_version": "4.54.0"
+ }
model-00001-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:56e81a56861107692d1aff7d38b998978e62d93cbf64458a9e64631673068e21
+ size 4983087160
model-00002-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3ab3ef42207e9617581bd3b523c0f561574936457baabfc65090ec2afa212275
+ size 3045217160
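The two shard entries above are git-LFS pointer files, not the weights themselves. A downloaded shard can be checked against the pointer's `oid` and `size`; a hedged sketch, assuming the shard sits in the current directory:

```python
# Sketch: verify a downloaded shard against its git-lfs pointer (sha256 oid + size).
import hashlib
import os

def verify_shard(path: str, expected_sha256: str, expected_size: int) -> bool:
    if os.path.getsize(path) != expected_size:
        return False
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):  # read in 1 MiB chunks
            digest.update(chunk)
    return digest.hexdigest() == expected_sha256

print(verify_shard(
    "model-00002-of-00002.safetensors",
    "3ab3ef42207e9617581bd3b523c0f561574936457baabfc65090ec2afa212275",
    3045217160,
))
```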
model.safetensors.index.json ADDED
@@ -0,0 +1,281 @@
1
+ {
2
+ "metadata": {
3
+ "total_parameters": 4014136320,
4
+ "total_size": 8028272640
5
+ },
6
+ "weight_map": {
7
+ "lm_head.weight": "model-00002-of-00002.safetensors",
8
+ "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
9
+ "model.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
10
+ "model.layers.0.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
11
+ "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
12
+ "model.layers.0.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
13
+ "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
14
+ "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
15
+ "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
16
+ "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
17
+ "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
18
+ "model.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
19
+ "model.layers.1.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
20
+ "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
21
+ "model.layers.1.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
22
+ "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
23
+ "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
24
+ "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
25
+ "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
26
+ "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
27
+ "model.layers.10.input_layernorm.weight": "model-00001-of-00002.safetensors",
28
+ "model.layers.10.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
29
+ "model.layers.10.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
30
+ "model.layers.10.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
31
+ "model.layers.10.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
32
+ "model.layers.10.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
33
+ "model.layers.10.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
34
+ "model.layers.10.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
35
+ "model.layers.10.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
36
+ "model.layers.11.input_layernorm.weight": "model-00001-of-00002.safetensors",
37
+ "model.layers.11.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
38
+ "model.layers.11.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
39
+ "model.layers.11.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
40
+ "model.layers.11.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
41
+ "model.layers.11.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
42
+ "model.layers.11.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
43
+ "model.layers.11.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
44
+ "model.layers.11.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
45
+ "model.layers.12.input_layernorm.weight": "model-00001-of-00002.safetensors",
46
+ "model.layers.12.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
47
+ "model.layers.12.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
48
+ "model.layers.12.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
49
+ "model.layers.12.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
50
+ "model.layers.12.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
51
+ "model.layers.12.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
52
+ "model.layers.12.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
53
+ "model.layers.12.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
54
+ "model.layers.13.input_layernorm.weight": "model-00001-of-00002.safetensors",
55
+ "model.layers.13.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
56
+ "model.layers.13.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
57
+ "model.layers.13.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
58
+ "model.layers.13.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
59
+ "model.layers.13.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
60
+ "model.layers.13.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
61
+ "model.layers.13.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
62
+ "model.layers.13.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
63
+ "model.layers.14.input_layernorm.weight": "model-00001-of-00002.safetensors",
64
+ "model.layers.14.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
65
+ "model.layers.14.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
66
+ "model.layers.14.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
67
+ "model.layers.14.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
68
+ "model.layers.14.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
69
+ "model.layers.14.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
70
+ "model.layers.14.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
71
+ "model.layers.14.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
72
+ "model.layers.15.input_layernorm.weight": "model-00001-of-00002.safetensors",
73
+ "model.layers.15.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
74
+ "model.layers.15.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
75
+ "model.layers.15.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
76
+ "model.layers.15.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
77
+ "model.layers.15.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
78
+ "model.layers.15.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
79
+ "model.layers.15.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
80
+ "model.layers.15.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
81
+ "model.layers.16.input_layernorm.weight": "model-00001-of-00002.safetensors",
82
+ "model.layers.16.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
83
+ "model.layers.16.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
84
+ "model.layers.16.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
85
+ "model.layers.16.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
86
+ "model.layers.16.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
87
+ "model.layers.16.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
88
+ "model.layers.16.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
89
+ "model.layers.16.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
90
+ "model.layers.17.input_layernorm.weight": "model-00001-of-00002.safetensors",
91
+ "model.layers.17.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
92
+ "model.layers.17.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
93
+ "model.layers.17.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
94
+ "model.layers.17.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
95
+ "model.layers.17.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
96
+ "model.layers.17.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
97
+ "model.layers.17.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
98
+ "model.layers.17.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
99
+ "model.layers.18.input_layernorm.weight": "model-00001-of-00002.safetensors",
100
+ "model.layers.18.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
101
+ "model.layers.18.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
102
+ "model.layers.18.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
103
+ "model.layers.18.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
104
+ "model.layers.18.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
105
+ "model.layers.18.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
106
+ "model.layers.18.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
107
+ "model.layers.18.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
108
+ "model.layers.19.input_layernorm.weight": "model-00002-of-00002.safetensors",
109
+ "model.layers.19.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
110
+ "model.layers.19.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
111
+ "model.layers.19.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
112
+ "model.layers.19.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
113
+ "model.layers.19.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
114
+ "model.layers.19.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
115
+ "model.layers.19.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
116
+ "model.layers.19.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
117
+ "model.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
118
+ "model.layers.2.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
119
+ "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
120
+ "model.layers.2.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
121
+ "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
122
+ "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
123
+ "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
124
+ "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
125
+ "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
126
+ "model.layers.20.input_layernorm.weight": "model-00002-of-00002.safetensors",
127
+ "model.layers.20.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
128
+ "model.layers.20.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
129
+ "model.layers.20.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
130
+ "model.layers.20.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
131
+ "model.layers.20.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
132
+ "model.layers.20.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
133
+ "model.layers.20.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
134
+ "model.layers.20.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
135
+ "model.layers.21.input_layernorm.weight": "model-00002-of-00002.safetensors",
136
+ "model.layers.21.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
137
+ "model.layers.21.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
138
+ "model.layers.21.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
139
+ "model.layers.21.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
140
+ "model.layers.21.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
141
+ "model.layers.21.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
142
+ "model.layers.21.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
143
+ "model.layers.21.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
144
+ "model.layers.22.input_layernorm.weight": "model-00002-of-00002.safetensors",
145
+ "model.layers.22.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
146
+ "model.layers.22.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
147
+ "model.layers.22.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
148
+ "model.layers.22.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
149
+ "model.layers.22.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
150
+ "model.layers.22.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
151
+ "model.layers.22.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
152
+ "model.layers.22.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
153
+ "model.layers.23.input_layernorm.weight": "model-00002-of-00002.safetensors",
154
+ "model.layers.23.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
155
+ "model.layers.23.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
156
+ "model.layers.23.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
157
+ "model.layers.23.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
158
+ "model.layers.23.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
159
+ "model.layers.23.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
160
+ "model.layers.23.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
161
+ "model.layers.23.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
162
+ "model.layers.24.input_layernorm.weight": "model-00002-of-00002.safetensors",
163
+ "model.layers.24.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
164
+ "model.layers.24.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
165
+ "model.layers.24.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
166
+ "model.layers.24.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
167
+ "model.layers.24.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
168
+ "model.layers.24.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
169
+ "model.layers.24.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
170
+ "model.layers.24.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
171
+ "model.layers.25.input_layernorm.weight": "model-00002-of-00002.safetensors",
172
+ "model.layers.25.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
173
+ "model.layers.25.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
174
+ "model.layers.25.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
175
+ "model.layers.25.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
176
+ "model.layers.25.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
177
+ "model.layers.25.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
178
+ "model.layers.25.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
179
+ "model.layers.25.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
180
+ "model.layers.26.input_layernorm.weight": "model-00002-of-00002.safetensors",
181
+ "model.layers.26.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
182
+ "model.layers.26.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
183
+ "model.layers.26.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
184
+ "model.layers.26.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
185
+ "model.layers.26.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
186
+ "model.layers.26.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
187
+ "model.layers.26.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
188
+ "model.layers.26.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
189
+ "model.layers.27.input_layernorm.weight": "model-00002-of-00002.safetensors",
190
+ "model.layers.27.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
191
+ "model.layers.27.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
192
+ "model.layers.27.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
193
+ "model.layers.27.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
194
+ "model.layers.27.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
195
+ "model.layers.27.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
196
+ "model.layers.27.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
197
+ "model.layers.27.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
198
+ "model.layers.28.input_layernorm.weight": "model-00002-of-00002.safetensors",
199
+ "model.layers.28.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
200
+ "model.layers.28.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
201
+ "model.layers.28.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
202
+ "model.layers.28.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
203
+ "model.layers.28.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
204
+ "model.layers.28.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
205
+ "model.layers.28.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
206
+ "model.layers.28.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
207
+ "model.layers.29.input_layernorm.weight": "model-00002-of-00002.safetensors",
208
+ "model.layers.29.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
209
+ "model.layers.29.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
210
+ "model.layers.29.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
211
+ "model.layers.29.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
212
+ "model.layers.29.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
213
+ "model.layers.29.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
214
+ "model.layers.29.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
215
+ "model.layers.29.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
216
+ "model.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
217
+ "model.layers.3.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
218
+ "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
219
+ "model.layers.3.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
220
+ "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
221
+ "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
222
+ "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
223
+ "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
224
+ "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
225
+ "model.layers.4.input_layernorm.weight": "model-00001-of-00002.safetensors",
226
+ "model.layers.4.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
227
+ "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
228
+ "model.layers.4.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
229
+ "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
230
+ "model.layers.4.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
231
+ "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
232
+ "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
233
+ "model.layers.4.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
234
+ "model.layers.5.input_layernorm.weight": "model-00001-of-00002.safetensors",
235
+ "model.layers.5.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
236
+ "model.layers.5.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
237
+ "model.layers.5.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
238
+ "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
239
+ "model.layers.5.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
240
+ "model.layers.5.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
241
+ "model.layers.5.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
242
+ "model.layers.5.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
243
+ "model.layers.6.input_layernorm.weight": "model-00001-of-00002.safetensors",
244
+ "model.layers.6.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
245
+ "model.layers.6.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
246
+ "model.layers.6.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
247
+ "model.layers.6.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
248
+ "model.layers.6.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
249
+ "model.layers.6.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
250
+ "model.layers.6.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
251
+ "model.layers.6.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
252
+ "model.layers.7.input_layernorm.weight": "model-00001-of-00002.safetensors",
253
+ "model.layers.7.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
254
+ "model.layers.7.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
255
+ "model.layers.7.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
256
+ "model.layers.7.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
257
+ "model.layers.7.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
258
+ "model.layers.7.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
259
+ "model.layers.7.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
260
+ "model.layers.7.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
261
+ "model.layers.8.input_layernorm.weight": "model-00001-of-00002.safetensors",
262
+ "model.layers.8.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
263
+ "model.layers.8.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
264
+ "model.layers.8.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
265
+ "model.layers.8.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
266
+ "model.layers.8.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
267
+ "model.layers.8.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
268
+ "model.layers.8.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
269
+ "model.layers.8.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
270
+ "model.layers.9.input_layernorm.weight": "model-00001-of-00002.safetensors",
271
+ "model.layers.9.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
272
+ "model.layers.9.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
273
+ "model.layers.9.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
274
+ "model.layers.9.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
275
+ "model.layers.9.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
276
+ "model.layers.9.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
277
+ "model.layers.9.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
278
+ "model.layers.9.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
279
+ "model.norm.weight": "model-00002-of-00002.safetensors"
280
+ }
281
+ }
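The index above maps each parameter name to the shard that stores it; `from_pretrained` uses it automatically, but it can also be read directly. A hedged sketch, assuming the index and shards are downloaded locally:

```python
# Sketch: find which shard holds a tensor via model.safetensors.index.json,
# then load only that shard with safetensors.
import json
from safetensors.torch import load_file

with open("model.safetensors.index.json") as f:
    index = json.load(f)

name = "model.layers.19.mlp.gate_proj.weight"   # layer 19 is split across shards, per the map above
shard = index["weight_map"][name]               # -> "model-00001-of-00002.safetensors"
tensor = load_file(shard)[name]
print(name, tuple(tensor.shape), tensor.dtype)
```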
modeling_loopstral.py ADDED
@@ -0,0 +1,518 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/mistral/modular_mistral.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_mistral.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ from typing import Callable, Optional, Union
8
+
9
+ import torch
10
+ from torch import nn
11
+ from torch.nn import CrossEntropyLoss
12
+
13
+ #from transformers.modeling_utils import check_model_inputs
14
+
15
+ from transformers.activations import ACT2FN
16
+ from transformers.cache_utils import Cache, DynamicCache
17
+ from transformers.generation import GenerationMixin
18
+ from transformers.integrations import use_kernel_forward_from_hub
19
+ from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
20
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
21
+ from transformers.modeling_layers import (
22
+ GenericForQuestionAnswering,
23
+ GenericForSequenceClassification,
24
+ GenericForTokenClassification,
25
+ GradientCheckpointingLayer,
26
+ )
27
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
28
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
29
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
30
+ from transformers.processing_utils import Unpack
31
+ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
32
+ #from transformers.utils.doc import auto_docstring
33
+ from transformers.utils.deprecation import deprecate_kwarg
34
+ from .configuration_loopstral import LoopstralConfig
35
+
36
+
37
+ class MistralMLP(nn.Module):
38
+ def __init__(self, config):
39
+ super().__init__()
40
+ self.config = config
41
+ self.hidden_size = config.hidden_size
42
+ self.intermediate_size = config.intermediate_size
43
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
44
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
45
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
46
+ self.act_fn = ACT2FN[config.hidden_act]
47
+
48
+ def forward(self, x):
49
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
50
+ return down_proj
51
+
52
+
53
+ def rotate_half(x):
54
+ """Rotates half the hidden dims of the input."""
55
+ x1 = x[..., : x.shape[-1] // 2]
56
+ x2 = x[..., x.shape[-1] // 2 :]
57
+ return torch.cat((-x2, x1), dim=-1)
58
+
59
+
60
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
61
+ """Applies Rotary Position Embedding to the query and key tensors.
62
+
63
+ Args:
64
+ q (`torch.Tensor`): The query tensor.
65
+ k (`torch.Tensor`): The key tensor.
66
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
67
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
68
+ position_ids (`torch.Tensor`, *optional*):
69
+ Deprecated and unused.
70
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
71
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
72
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
73
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
74
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
75
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
76
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
77
+ Returns:
78
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
79
+ """
80
+ cos = cos.unsqueeze(unsqueeze_dim)
81
+ sin = sin.unsqueeze(unsqueeze_dim)
82
+ q_embed = (q * cos) + (rotate_half(q) * sin)
83
+ k_embed = (k * cos) + (rotate_half(k) * sin)
84
+ return q_embed, k_embed
85
+
86
+
87
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
88
+ """
89
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
90
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
91
+ """
92
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
93
+ if n_rep == 1:
94
+ return hidden_states
95
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
96
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
97
+
98
+
99
+ def eager_attention_forward(
100
+ module: nn.Module,
101
+ query: torch.Tensor,
102
+ key: torch.Tensor,
103
+ value: torch.Tensor,
104
+ attention_mask: Optional[torch.Tensor],
105
+ scaling: float,
106
+ dropout: float = 0.0,
107
+ **kwargs: Unpack[TransformersKwargs],
108
+ ):
109
+ key_states = repeat_kv(key, module.num_key_value_groups)
110
+ value_states = repeat_kv(value, module.num_key_value_groups)
111
+
112
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
113
+ if attention_mask is not None:
114
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
115
+ attn_weights = attn_weights + causal_mask
116
+
117
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
118
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
119
+ attn_output = torch.matmul(attn_weights, value_states)
120
+ attn_output = attn_output.transpose(1, 2).contiguous()
121
+
122
+ return attn_output, attn_weights
123
+
124
+
125
+ class MistralAttention(nn.Module):
126
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
127
+
128
+ def __init__(self, config: LoopstralConfig, layer_idx: int):
129
+ super().__init__()
130
+ self.config = config
131
+ self.layer_idx = layer_idx
132
+ self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
133
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
134
+ self.scaling = self.head_dim**-0.5
135
+ self.attention_dropout = config.attention_dropout
136
+ self.is_causal = True
137
+ self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
138
+ self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
139
+ self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
140
+ self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
141
+
142
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
143
+ def forward(
144
+ self,
145
+ hidden_states: torch.Tensor,
146
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
147
+ attention_mask: Optional[torch.Tensor],
148
+ past_key_values: Optional[Cache] = None,
149
+ cache_position: Optional[torch.LongTensor] = None,
150
+ update_cache: Optional[bool] = True,
151
+ **kwargs: Unpack[FlashAttentionKwargs],
152
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
153
+ input_shape = hidden_states.shape[:-1]
154
+ hidden_shape = (*input_shape, -1, self.head_dim)
155
+
156
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
157
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
158
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
159
+
160
+ cos, sin = position_embeddings
161
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
162
+
163
+ if past_key_values is not None:
164
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
165
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
166
+
167
+ # --- START DEBUGGING CODE ---
168
+ # Add these two lines to see what the object is and what's inside it.
169
+ #print(f"DEBUG: Type of cache object: {type(past_key_values)}")
170
+ #print(f"DEBUG: Attributes of cache object: {dir(past_key_values)}")
171
+ # --- END DEBUGGING CODE ---
172
+
173
+ if update_cache:
174
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
175
+ else:
176
+ k_cache, v_cache = past_key_values[self.layer_idx]
177
+ if k_cache is not None:
178
+ key_states = torch.cat([k_cache, key_states], dim=2)
179
+ value_states = torch.cat([v_cache, value_states], dim=2)
180
+
181
+ attention_interface: Callable = eager_attention_forward
182
+ if self.config._attn_implementation != "eager":
183
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
184
+
185
+ attn_output, attn_weights = attention_interface(
186
+ self,
187
+ query_states,
188
+ key_states,
189
+ value_states,
190
+ attention_mask,
191
+ dropout=0.0 if not self.training else self.attention_dropout,
192
+ scaling=self.scaling,
193
+ sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama
194
+ **kwargs,
195
+ )
196
+
197
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
198
+ attn_output = self.o_proj(attn_output)
199
+ return attn_output, attn_weights
200
+
201
+
202
+ @use_kernel_forward_from_hub("RMSNorm")
203
+ class MistralRMSNorm(nn.Module):
204
+ def __init__(self, hidden_size, eps=1e-6):
205
+ """
206
+ MistralRMSNorm is equivalent to T5LayerNorm
207
+ """
208
+ super().__init__()
209
+ self.weight = nn.Parameter(torch.ones(hidden_size))
210
+ self.variance_epsilon = eps
211
+
212
+ def forward(self, hidden_states):
213
+ input_dtype = hidden_states.dtype
214
+ hidden_states = hidden_states.to(torch.float32)
215
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
216
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
217
+ return self.weight * hidden_states.to(input_dtype)
218
+
219
+ def extra_repr(self):
220
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
221
+
222
+
223
+ class MistralDecoderLayer(GradientCheckpointingLayer):
224
+ def __init__(self, config: LoopstralConfig, layer_idx: int):
225
+ super().__init__()
226
+ self.hidden_size = config.hidden_size
227
+ self.self_attn = MistralAttention(config=config, layer_idx=layer_idx)
228
+ self.mlp = MistralMLP(config)
229
+ self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
230
+ self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
231
+
232
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
233
+ def forward(
234
+ self,
235
+ hidden_states: torch.Tensor,
236
+ attention_mask: Optional[torch.Tensor] = None,
237
+ position_ids: Optional[torch.LongTensor] = None,
238
+ past_key_values: Optional[Cache] = None,
239
+ use_cache: Optional[bool] = False,
240
+ cache_position: Optional[torch.LongTensor] = None,
241
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
242
+ update_cache: Optional[bool] = True,
243
+ **kwargs: Unpack[TransformersKwargs],
244
+ ) -> torch.Tensor:
245
+ residual = hidden_states
246
+ hidden_states = self.input_layernorm(hidden_states)
247
+ # Self Attention
248
+ hidden_states, _ = self.self_attn(
249
+ hidden_states=hidden_states,
250
+ attention_mask=attention_mask,
251
+ position_ids=position_ids,
252
+ past_key_values=past_key_values,
253
+ use_cache=use_cache,
254
+ cache_position=cache_position,
255
+ position_embeddings=position_embeddings,
256
+ update_cache=update_cache,
257
+ **kwargs,
258
+ )
259
+ hidden_states = residual + hidden_states
260
+
261
+ # Fully Connected
262
+ residual = hidden_states
263
+ hidden_states = self.post_attention_layernorm(hidden_states)
264
+ hidden_states = self.mlp(hidden_states)
265
+ hidden_states = residual + hidden_states
266
+ return hidden_states
267
+
268
+
269
+ @auto_docstring
270
+ class MistralPreTrainedModel(PreTrainedModel):
271
+ config: LoopstralConfig
272
+ base_model_prefix = "model"
273
+ supports_gradient_checkpointing = True
274
+ _no_split_modules = ["MistralDecoderLayer"]
275
+ _skip_keys_device_placement = ["past_key_values"]
276
+ _supports_flash_attn = True
277
+ _supports_sdpa = True
278
+ _supports_flex_attn = True
279
+
280
+ _can_compile_fullgraph = True
281
+ _supports_attention_backend = True
282
+ _can_record_outputs = {
283
+ "hidden_states": MistralDecoderLayer,
284
+ "attentions": MistralAttention,
285
+ }
286
+
287
+
288
+ class MistralRotaryEmbedding(nn.Module):
289
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
290
+
291
+ def __init__(self, config: LoopstralConfig, device=None):
292
+ super().__init__()
293
+ # BC: "rope_type" was originally "type"
294
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
295
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
296
+ else:
297
+ self.rope_type = "default"
298
+ self.max_seq_len_cached = config.max_position_embeddings
299
+ self.original_max_seq_len = config.max_position_embeddings
300
+
301
+ self.config = config
302
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
303
+
304
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
305
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
306
+ self.original_inv_freq = self.inv_freq
307
+
308
+ @torch.no_grad()
309
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
310
+ def forward(self, x, position_ids):
311
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
312
+ position_ids_expanded = position_ids[:, None, :].float()
313
+
314
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
315
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
316
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
317
+ emb = torch.cat((freqs, freqs), dim=-1)
318
+ cos = emb.cos() * self.attention_scaling
319
+ sin = emb.sin() * self.attention_scaling
320
+
321
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
322
+
323
+
324
+ @auto_docstring
325
+ class LoopstralModel(MistralPreTrainedModel):
326
+ def __init__(self, config: LoopstralConfig):
327
+ super().__init__(config)
328
+ self.padding_idx = config.pad_token_id
329
+ self.vocab_size = config.vocab_size
330
+
331
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
332
+ self.layers = nn.ModuleList(
333
+ [MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
334
+ )
335
+ self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
336
+ self.rotary_emb = MistralRotaryEmbedding(config=config)
337
+ self.gradient_checkpointing = False
338
+
339
+ # Initialize weights and apply final processing
340
+ self.post_init()
341
+
342
+ #@check_model_inputs
343
+ @auto_docstring
344
+ def forward(
345
+ self,
346
+ input_ids: Optional[torch.LongTensor] = None,
347
+ attention_mask: Optional[torch.Tensor] = None,
348
+ position_ids: Optional[torch.LongTensor] = None,
349
+ past_key_values: Optional[Cache] = None,
350
+ inputs_embeds: Optional[torch.FloatTensor] = None,
351
+ use_cache: Optional[bool] = None,
352
+ cache_position: Optional[torch.LongTensor] = None,
353
+ **kwargs: Unpack[TransformersKwargs],
354
+ ) -> BaseModelOutputWithPast:
355
+ if (input_ids is None) ^ (inputs_embeds is not None):
356
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
357
+
358
+ if inputs_embeds is None:
359
+ inputs_embeds = self.embed_tokens(input_ids)
360
+
361
+ if use_cache and past_key_values is None:
362
+ past_key_values = DynamicCache(config=self.config)
363
+
364
+ if cache_position is None:
365
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
366
+ cache_position = torch.arange(
367
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
368
+ )
369
+
370
+ if position_ids is None:
371
+ position_ids = cache_position.unsqueeze(0)
372
+
373
+ mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
374
+ causal_mask = mask_function(
375
+ config=self.config,
376
+ input_embeds=inputs_embeds,
377
+ attention_mask=attention_mask,
378
+ cache_position=cache_position,
379
+ past_key_values=past_key_values,
380
+ position_ids=position_ids,
381
+ )
382
+
383
+ hidden_states = inputs_embeds
384
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
385
+ #***
386
+ #Create the loop sequence!
387
+ #***
388
+ l_seq = []
389
+ #print(self.config.layer_sequence)
390
+ for seq in self.config.layer_sequence:
391
+ l_seq += [i for i in range(seq[0],min(seq[1], self.config.num_hidden_layers))]*seq[2]
392
+ #print(f"DEBUG: Layer sequence {l_seq}")
393
+
394
+ last_visit_map = {layer_idx: i for i, layer_idx in enumerate(l_seq)}
395
+ for i, layer in enumerate(l_seq):
396
+ should_update_cache = use_cache and (last_visit_map[layer] == i)
397
+ decoder_layer = self.layers[layer]
398
+ hidden_states = decoder_layer(
399
+ hidden_states,
400
+ attention_mask=causal_mask,
401
+ position_ids=position_ids,
402
+ past_key_values=past_key_values,
403
+ use_cache=use_cache,
404
+ cache_position=cache_position,
405
+ position_embeddings=position_embeddings,
406
+ update_cache=should_update_cache,
407
+ **kwargs,
408
+ )
409
+ hidden_states = self.norm(hidden_states)
410
+ return BaseModelOutputWithPast(
411
+ last_hidden_state=hidden_states,
412
+ past_key_values=past_key_values if use_cache else None,
413
+ )
414
+
415
+
416
+ @auto_docstring
417
+ class LoopstralForCausalLM(MistralPreTrainedModel, GenerationMixin):
418
+ _tied_weights_keys = ["lm_head.weight"]
419
+ _tp_plan = {"lm_head": "colwise_rep"}
420
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
421
+
422
+ def __init__(self, config):
423
+ super().__init__(config)
424
+ self.model = LoopstralModel(config)
425
+ self.vocab_size = config.vocab_size
426
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
427
+
428
+ # Initialize weights and apply final processing
429
+ self.post_init()
430
+
431
+ @can_return_tuple
432
+ @auto_docstring
433
+ def forward(
434
+ self,
435
+ input_ids: Optional[torch.LongTensor] = None,
436
+ attention_mask: Optional[torch.Tensor] = None,
437
+ position_ids: Optional[torch.LongTensor] = None,
438
+ past_key_values: Optional[Cache] = None,
439
+ inputs_embeds: Optional[torch.FloatTensor] = None,
440
+ labels: Optional[torch.LongTensor] = None,
441
+ use_cache: Optional[bool] = None,
442
+ cache_position: Optional[torch.LongTensor] = None,
443
+ logits_to_keep: Union[int, torch.Tensor] = 0,
444
+ **kwargs: Unpack[TransformersKwargs],
445
+ ) -> CausalLMOutputWithPast:
446
+ r"""
447
+ Example:
448
+
449
+ ```python
450
+ >>> from transformers import AutoTokenizer, MistralForCausalLM
451
+
452
+ >>> model = MistralForCausalLM.from_pretrained("meta-mistral/Mistral-2-7b-hf")
453
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-mistral/Mistral-2-7b-hf")
454
+
455
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
456
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
457
+
458
+ >>> # Generate
459
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
460
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
461
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
462
+ ```"""
463
+ outputs: BaseModelOutputWithPast = self.model(
464
+ input_ids=input_ids,
465
+ attention_mask=attention_mask,
466
+ position_ids=position_ids,
467
+ past_key_values=past_key_values,
468
+ inputs_embeds=inputs_embeds,
469
+ use_cache=use_cache,
470
+ cache_position=cache_position,
471
+ **kwargs,
472
+ )
473
+
474
+ hidden_states = outputs.last_hidden_state
475
+ logits = self.lm_head(hidden_states)
476
+
477
+ loss = None
478
+ if labels is not None:
479
+ # THE FIX IS HERE: Standard loss calculation
480
+ # Shift so that tokens < n predict n
481
+ shift_logits = logits[..., :-1, :].contiguous()
482
+ shift_labels = labels[..., 1:].contiguous()
483
+ # Flatten the tokens
484
+ loss_fct = CrossEntropyLoss()
485
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
486
+ shift_labels = shift_labels.view(-1)
487
+ # Enable model parallelism
488
+ shift_labels = shift_labels.to(shift_logits.device)
489
+ loss = loss_fct(shift_logits, shift_labels)
490
+
491
+ return CausalLMOutputWithPast(
492
+ loss=loss,
493
+ logits=logits,
494
+ past_key_values=outputs.past_key_values,
495
+ hidden_states=outputs.hidden_states,
496
+ attentions=outputs.attentions,
497
+ )
498
+
499
+
500
+ class MistralForTokenClassification(GenericForTokenClassification, MistralPreTrainedModel):
501
+ pass
502
+
503
+
504
+ class MistralForSequenceClassification(GenericForSequenceClassification, MistralPreTrainedModel):
505
+ pass
506
+
507
+
508
+ class MistralForQuestionAnswering(GenericForQuestionAnswering, MistralPreTrainedModel): ...
509
+
510
+
511
+ __all__ = [
512
+ "LoopstralForCausalLM",
513
+ "MistralForQuestionAnswering",
514
+ "LoopstralModel",
515
+ "MistralPreTrainedModel",
516
+ "MistralForSequenceClassification",
517
+ "MistralForTokenClassification",
518
+ ]
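To make the loop scheduling in `LoopstralModel.forward` concrete, here is a standalone sketch of how `layer_sequence` expands into the layer execution order and which visit is allowed to write the KV cache (only the last visit to each physical layer, mirroring `last_visit_map` above). The example sequence is illustrative; the shipped config uses `[[0,30,1]]`.

```python
# Standalone sketch of the scheduling logic in LoopstralModel.forward (no model needed).
num_hidden_layers = 30
layer_sequence = [[0, 10, 1], [10, 20, 2], [20, 30, 1]]  # illustrative, not the shipped value

# Expand [start, end, num_loops] entries into the flat order of layer calls.
l_seq = []
for start, end, num_loops in layer_sequence:
    l_seq += [i for i in range(start, min(end, num_hidden_layers))] * num_loops

# A layer may be visited several times; only its final visit updates the KV cache,
# so repeated passes reuse the cached keys/values instead of appending them twice.
last_visit_map = {layer_idx: i for i, layer_idx in enumerate(l_seq)}
for i, layer_idx in enumerate(l_seq):
    update_cache = last_visit_map[layer_idx] == i
    # decoder_layer = self.layers[layer_idx]
    # hidden_states = decoder_layer(hidden_states, ..., update_cache=update_cache)

print(f"{len(l_seq)} layer calls scheduled over {num_hidden_layers} physical layers")
```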
saved_encodings.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a645aed1642548fb3dcdb228c21cc80df434cc0623ee54cad973de53ae19d0d5
+ size 24767816
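The contents of `saved_encodings.pth` are not documented in this commit, so the sketch below only inspects the top-level object after download; it assumes nothing about its structure and should only be run on files you trust (arbitrary pickles can execute code on load).

```python
# Sketch: inspect saved_encodings.pth without assuming its structure.
import torch

# weights_only=False is needed for arbitrary pickled objects; only use on trusted files.
obj = torch.load("saved_encodings.pth", map_location="cpu", weights_only=False)
print(type(obj))
if isinstance(obj, dict):
    for key, value in list(obj.items())[:5]:
        print(key, getattr(value, "shape", type(value)))
```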
special_tokens_map.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff