PursuitOfDataScience committed on
Commit
95e5b13
·
verified ·
1 Parent(s): f8f14b8

Upload model with sharded safetensors

Browse files
.gitattributes ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ model-00001-of-00005.safetensors filter=lfs diff=lfs merge=lfs -text
2
+ model-00002-of-00005.safetensors filter=lfs diff=lfs merge=lfs -text
3
+ model-00003-of-00005.safetensors filter=lfs diff=lfs merge=lfs -text
4
+ model-00004-of-00005.safetensors filter=lfs diff=lfs merge=lfs -text
5
+ model-00005-of-00005.safetensors filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ language:
4
+ - en
5
+ library_name: transformers
6
+ tags:
7
+ - text-generation
8
+ - causal-lm
9
+ - transformer
10
+ - argonne
11
+ - pretrained
12
+ pipeline_tag: text-generation
13
+ ---
14
+
15
+ # Argonne 2.0
16
+
17
+ A **4.9 billion parameter** decoder-only transformer language model trained from scratch.
18
+
19
+ ## Model Architecture
20
+
21
+ | Component | Specification |
22
+ |-----------|--------------|
23
+ | **Parameters** | ~4.9B |
24
+ | **Layers** | 24 transformer blocks |
25
+ | **Hidden Size** | 4,080 |
26
+ | **Attention Heads** | 24 query / 8 key-value (GQA) |
27
+ | **Context Length** | 4,096 tokens |
28
+ | **Vocabulary Size** | 151,665 |
29
+
30
+ ## Usage
31
+
32
+ ```python
33
+ from transformers import AutoModelForCausalLM, AutoTokenizer
34
+ import torch
35
+
36
+ model = AutoModelForCausalLM.from_pretrained(
37
+ "PursuitOfDataScience/Argonne-2.0",
38
+ torch_dtype=torch.bfloat16,
39
+ device_map="auto",
40
+ trust_remote_code=True
41
+ )
42
+ tokenizer = AutoTokenizer.from_pretrained("PursuitOfDataScience/Argonne-2.0", trust_remote_code=True)
43
+
44
+ prompt = "The future of AI is"
45
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
46
+ outputs = model.generate(**inputs, max_length=256, do_sample=True, temperature=0.7)
47
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
48
+ ```
49
+
50
+ ## License
51
+
52
+ Apache 2.0
53
+
54
+ ## Citation
55
+
56
+ ```bibtex
57
+ @misc{argonne2,
58
+ author = {PursuitOfDataScience},
59
+ title = {Argonne 2.0: A 4.9B Parameter Language Model},
60
+ year = {2026},
61
+ publisher = {Hugging Face},
62
+ url = {https://huggingface.co/PursuitOfDataScience/Argonne-2.0}
63
+ }
64
+ ```
65
+
66
+ ## Links
67
+
68
+ - GitHub: [PursuitOfDataScience](https://github.com/PursuitOfDataScience)
69
+ - Hugging Face: [PursuitOfDataScience](https://huggingface.co/PursuitOfDataScience)
added_tokens.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "</tool_call>": 151658,
3
+ "<tool_call>": 151657,
4
+ "<|box_end|>": 151649,
5
+ "<|box_start|>": 151648,
6
+ "<|endoftext|>": 151643,
7
+ "<|file_sep|>": 151664,
8
+ "<|fim_middle|>": 151660,
9
+ "<|fim_pad|>": 151662,
10
+ "<|fim_prefix|>": 151659,
11
+ "<|fim_suffix|>": 151661,
12
+ "<|im_end|>": 151645,
13
+ "<|im_start|>": 151644,
14
+ "<|image_pad|>": 151655,
15
+ "<|object_ref_end|>": 151647,
16
+ "<|object_ref_start|>": 151646,
17
+ "<|quad_end|>": 151651,
18
+ "<|quad_start|>": 151650,
19
+ "<|repo_name|>": 151663,
20
+ "<|video_pad|>": 151656,
21
+ "<|vision_end|>": 151653,
22
+ "<|vision_pad|>": 151654,
23
+ "<|vision_start|>": 151652
24
+ }
config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "ArgonneModel"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "block_size": 4096,
8
+ "eos_token_id": 151645,
9
+ "hidden_dropout": 0.0,
10
+ "hidden_size": 4080,
11
+ "intermediate_size": 11008,
12
+ "max_position_embeddings": 4096,
13
+ "mlp_bias": false,
14
+ "model_type": "argonne2",
15
+ "n_embd": 4080,
16
+ "n_head": 24,
17
+ "n_layer": 24,
18
+ "num_attention_heads": 24,
19
+ "num_hidden_layers": 24,
20
+ "num_key_value_heads": 8,
21
+ "pad_token_id": 151643,
22
+ "rms_norm_eps": 1e-06,
23
+ "rope_theta": 10000.0,
24
+ "sliding_window": null,
25
+ "torch_dtype": "float32",
26
+ "transformers_version": "4.44.0",
27
+ "use_flash_attention": true,
28
+ "use_gradient_checkpointing": false,
29
+ "vocab_size": 151665
30
+ }
generation_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "eos_token_id": 151645,
4
+ "pad_token_id": 151643,
5
+ "transformers_version": "4.44.0"
6
+ }
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model-00001-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2d8d8c974aba12ebc3a6d89d854fe13145b6a12e0530bccc486071e3e4d15957
3
+ size 5309647128
model-00002-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:60ca8b990e3a42ee3d928d45e5a2c4c439b4e819c5743ed79cd1089e57bc62b6
3
+ size 5209869352
model-00003-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:de22769975a3f0760e6490c487f415fd877eacd3696deeccf5b697674cb44c59
3
+ size 5209869376
model-00004-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e956c8268b484a8891958ed9ea663a92f50a4a06d8489492c8509d2df09017e5
3
+ size 5351921856
model-00005-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fbb1caffb0be41117446f9ca814c4faae0688b9dcaf1146893f23b7ec163f1eb
3
+ size 1066179000
model.py ADDED
@@ -0,0 +1,852 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import importlib.util
3
+ from bisect import bisect_left, bisect_right
4
+ from typing import List, Optional, Tuple
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from transformers import (
10
+ AutoConfig,
11
+ AutoModel,
12
+ AutoModelForCausalLM,
13
+ PreTrainedModel,
14
+ PretrainedConfig,
15
+ )
16
+ from transformers.modeling_outputs import CausalLMOutput
17
+
18
+
19
# Optional-dependency probe: flash-attn 2 kernels are used only when the
# package is importable; otherwise attention falls back to PyTorch paths.
_flash_attn_available = importlib.util.find_spec("flash_attn") is not None
if _flash_attn_available:
    from flash_attn.flash_attn_interface import flash_attn_func
22
+
23
+
24
class ArgonneConfig(PretrainedConfig):
    """Configuration for the Argonne v2 family of models.

    Extends HF ``PretrainedConfig`` with the decoder-stack geometry
    (hidden size, layer/head counts, grouped-query KV heads), numerical
    settings (RMSNorm eps, RoPE theta), and execution toggles (flash
    attention, gradient checkpointing). Argonne 1.x key names
    (``n_layer``, ``n_head``, ``n_embd``, ``block_size``) arriving via
    ``kwargs`` override the new-style parameters, and the old names are
    re-exported as aliases for legacy callers.
    """

    model_type = "argonne2"

    def __init__(
        self,
        vocab_size: int = 32000,
        hidden_size: int = 4096,
        num_hidden_layers: int = 48,
        num_attention_heads: int = 32,
        num_key_value_heads: Optional[int] = None,
        intermediate_size: Optional[int] = None,
        max_position_embeddings: int = 4096,
        attention_dropout: float = 0.0,
        hidden_dropout: float = 0.0,
        rms_norm_eps: float = 1e-6,
        rope_theta: float = 10000.0,
        sliding_window: Optional[int] = None,
        use_flash_attention: bool = True,
        use_gradient_checkpointing: bool = False,
        tie_word_embeddings: bool = True,
        attention_bias: bool = False,
        mlp_bias: bool = False,
        pad_token_id: Optional[int] = None,
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        **kwargs,
    ) -> None:
        # NOTE(review): these pops look redundant — the explicit keyword
        # parameters above already capture pad/bos/eos ids, so kwargs should
        # never contain them here. Kept as-is; confirm before removing.
        pad_token_id = pad_token_id if pad_token_id is not None else kwargs.pop("pad_token_id", None)
        bos_token_id = bos_token_id if bos_token_id is not None else kwargs.pop("bos_token_id", None)
        eos_token_id = eos_token_id if eos_token_id is not None else kwargs.pop("eos_token_id", None)

        # PretrainedConfig.__init__ stores the special-token ids and any
        # remaining kwargs; it runs before our own fields are assigned.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )
        # Backwards compatibility with Argonne 1.x naming.
        if "n_layer" in kwargs:
            num_hidden_layers = kwargs["n_layer"]
        if "n_head" in kwargs:
            num_attention_heads = kwargs["n_head"]
        if "n_embd" in kwargs:
            hidden_size = kwargs["n_embd"]
        if "block_size" in kwargs:
            max_position_embeddings = kwargs["block_size"]

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        # GQA default: half as many KV heads as query heads, floored at 1.
        self.num_key_value_heads = (
            num_key_value_heads if num_key_value_heads is not None else num_attention_heads // 2
        )
        if self.num_key_value_heads < 1:
            self.num_key_value_heads = 1
        if num_attention_heads % self.num_key_value_heads != 0:
            raise ValueError("num_attention_heads must be divisible by num_key_value_heads")

        # SwiGLU convention: ~8/3 * hidden, rounded up to a multiple of 256.
        if intermediate_size is None:
            width = int(8 * hidden_size / 3)
            self.intermediate_size = ((width + 255) // 256) * 256
        else:
            self.intermediate_size = intermediate_size

        self.max_position_embeddings = max_position_embeddings
        self.attention_dropout = attention_dropout
        self.hidden_dropout = hidden_dropout
        self.rms_norm_eps = rms_norm_eps
        self.rope_theta = rope_theta
        self.sliding_window = sliding_window
        self.use_flash_attention = use_flash_attention
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.tie_word_embeddings = tie_word_embeddings
        self.attention_bias = attention_bias
        self.mlp_bias = mlp_bias

        # Default padding to EOS when no pad token was supplied.
        if self.pad_token_id is None and self.eos_token_id is not None:
            self.pad_token_id = self.eos_token_id

        # Backwards compatibility aliases
        self.n_embd = self.hidden_size
        self.n_layer = self.num_hidden_layers
        self.n_head = self.num_attention_heads
        self.block_size = self.max_position_embeddings
111
+
112
+
113
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization with a learned elementwise scale.

    Statistics are computed in float32 — inputs are first clamped to the
    float16-representable range so the squared term cannot overflow — and the
    normalized tensor is cast back to the caller's dtype before scaling.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        # Promote to float32 and clamp so x**2 stays finite for fp16 inputs.
        h = x.to(torch.float32).clamp(min=-65504.0, max=65504.0)
        mean_square = h.square().mean(dim=-1, keepdim=True)
        normed = h * torch.rsqrt(mean_square + self.eps)
        return self.weight * normed.to(input_dtype)
127
+
128
+
129
class RotaryEmbedding(nn.Module):
    """Precomputed rotary position-embedding (RoPE) cos/sin tables.

    Tables are cached up to ``max_position_embeddings`` at construction and
    transparently regrown whenever a longer sequence is requested.
    """

    def __init__(
        self,
        dim: int,
        max_position_embeddings: int = 2048,
        base: float = 10000.0,
        device: Optional[torch.device] = None,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        exponents = torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim
        inv_freq = 1.0 / (self.base ** exponents)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._set_cos_sin_cache(
            max_position_embeddings, device or inv_freq.device, torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> None:
        # Rebuild the cos/sin tables for positions [0, seq_len).
        self.max_seq_len_cached = seq_len
        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(positions, self.inv_freq)
        # Duplicate the frequencies so the table spans the full head dim.
        angles = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", angles.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", angles.sin().to(dtype), persistent=False)

    def forward(self, x: torch.Tensor, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # Grow the cache lazily if a longer sequence than cached arrives.
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len, x.device, x.dtype)
        cos = self.cos_cached[:seq_len].to(dtype=x.dtype, device=x.device)
        sin = self.sin_cached[:seq_len].to(dtype=x.dtype, device=x.device)
        return cos, sin
164
+
165
+
166
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate the last dimension by half: ``[a, b] -> [-b, a]``.

    This is the auxiliary rotation used when applying rotary position
    embeddings to query/key tensors.
    """
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
170
+
171
+
172
def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    position_ids: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embeddings (RoPE) to query and key tensors.

    Without ``position_ids`` the (seq, head_dim) cos/sin tables are broadcast
    over the batch and head axes; with ``position_ids`` each token's row is
    gathered first and then broadcast over the head axis.
    """
    if position_ids is None:
        cos, sin = cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0)
    else:
        cos, sin = cos[position_ids].unsqueeze(1), sin[position_ids].unsqueeze(1)

    q_embed = q * cos + rotate_half(q) * sin
    k_embed = k * cos + rotate_half(k) * sin
    return q_embed, k_embed
190
+
191
+
192
class GroupedQueryAttention(nn.Module):
    """Causal self-attention with grouped-query key/value heads (GQA).

    ``num_attention_heads`` query heads share ``num_key_value_heads`` KV
    heads; K/V are repeated per group before the attention product. Three
    execution paths are tried in order: flash-attn 2 kernels (if installed),
    ``F.scaled_dot_product_attention``, and a plain matmul/softmax fallback.
    """

    def __init__(self, config: ArgonneConfig) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        # Per-head width; assumes hidden_size divides evenly by num_heads.
        self.head_dim = self.hidden_size // self.num_heads
        # How many query heads share each KV head.
        self.num_key_value_groups = self.num_heads // self.num_kv_heads
        self.sliding_window = config.sliding_window

        self.q_proj = nn.Linear(
            self.hidden_size,
            self.num_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.k_proj = nn.Linear(
            self.hidden_size,
            self.num_kv_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.v_proj = nn.Linear(
            self.hidden_size,
            self.num_kv_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim,
            self.hidden_size,
            bias=config.attention_bias,
        )
        # Marker read by ArgonneModel._init_weights: residual-output
        # projections are initialized with a smaller std.
        self.o_proj._is_residual = True

        self.attention_dropout = config.attention_dropout
        self.use_flash_attention = config.use_flash_attention

    def _repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
        # Expand (bsz, num_kv, seq, dim) -> (bsz, num_heads, seq, dim) by
        # repeating each KV head across its group of query heads.
        if self.num_key_value_groups == 1:
            return x
        bsz, num_kv, seqlen, head_dim = x.shape
        x = x[:, :, None, :, :].expand(bsz, num_kv, self.num_key_value_groups, seqlen, head_dim)
        return x.reshape(bsz, num_kv * self.num_key_value_groups, seqlen, head_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run causal self-attention over ``hidden_states``.

        Args:
            hidden_states: (bsz, seqlen, hidden_size) activations.
            position_embeddings: RoPE (cos, sin) tables for this sequence.
            attention_mask: optional additive 4D mask; when supplied it is
                assumed to already include the causal component.
        Returns:
            (bsz, seqlen, hidden_size) attention output after ``o_proj``.
        """
        bsz, seqlen, _ = hidden_states.shape

        query = self.q_proj(hidden_states)
        key = self.k_proj(hidden_states)
        value = self.v_proj(hidden_states)

        # (bsz, seq, heads*dim) -> (bsz, heads, seq, dim)
        query = query.view(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
        key = key.view(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
        value = value.view(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)

        cos, sin = position_embeddings
        query, key = apply_rotary_pos_emb(query, key, cos, sin)

        # Repeat KV so every query head has a matching key/value head.
        key = self._repeat_kv(key)
        value = self._repeat_kv(value)

        # flash-attn 2 path requires fp16/bf16 and no explicit mask.
        use_flash_attn_2 = (
            _flash_attn_available
            and self.use_flash_attention
            and attention_mask is None
            and query.dtype in (torch.float16, torch.bfloat16)
            and self.head_dim % 4 == 0
        )
        use_scaled_dot = (
            hasattr(F, "scaled_dot_product_attention")
            and self.use_flash_attention
            and query.dtype in (torch.float16, torch.bfloat16)
            and self.head_dim % 4 == 0
        )

        attn_output = None
        if use_flash_attn_2:
            try:
                flash_dropout = self.attention_dropout if self.training else 0.0
                window = (
                    (self.sliding_window, self.sliding_window)
                    if self.sliding_window is not None
                    else (-1, -1)
                )
                # flash_attn_func expects (bsz, seq, heads, dim) layout.
                q = query.transpose(1, 2).contiguous()
                k = key.transpose(1, 2).contiguous()
                v = value.transpose(1, 2).contiguous()
                attn_output = flash_attn_func(
                    q,
                    k,
                    v,
                    dropout_p=flash_dropout,
                    softmax_scale=None,
                    causal=True,
                    window_size=window,
                ).transpose(1, 2)
            except RuntimeError:
                # Kernel unavailable at runtime — fall through to SDPA/math.
                attn_output = None

        if attn_output is None and use_scaled_dot:
            try:
                # Use is_causal=True when no attention_mask (faster Flash Attention path)
                # When attention_mask is provided, we need to combine it with causal masking
                if attention_mask is None:
                    attn_output = F.scaled_dot_product_attention(
                        query,
                        key,
                        value,
                        attn_mask=None,
                        dropout_p=self.attention_dropout if self.training else 0.0,
                        is_causal=True,
                    )
                else:
                    # With attention_mask: need to pass it explicitly (slower but correct)
                    # attention_mask should be 4D: (bsz, 1, seq, seq) or broadcastable
                    attn_output = F.scaled_dot_product_attention(
                        query,
                        key,
                        value,
                        attn_mask=attention_mask,
                        dropout_p=self.attention_dropout if self.training else 0.0,
                        is_causal=False,  # Mask already includes causal component
                    )
            except RuntimeError:
                # Fallback to math attention when kernels are unavailable
                attn_output = None

        if attn_output is None:
            scores = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(self.head_dim)
            # Apply causal mask - use large negative instead of -inf for numerical stability
            causal_mask = torch.triu(
                torch.ones(seqlen, seqlen, dtype=torch.bool, device=hidden_states.device),
                diagonal=1,
            )
            mask_value = -65504.0  # Large negative instead of -inf
            scores = scores.masked_fill(causal_mask, mask_value)
            # Apply attention_mask if provided
            if attention_mask is not None:
                scores = scores + attention_mask
            # Softmax in float32 for stability, then cast back to query dtype.
            attn_weights = torch.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
            attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)
            attn_output = torch.matmul(attn_weights, value)

        # (bsz, heads, seq, dim) -> (bsz, seq, heads*dim) for the output proj.
        attn_output = (
            attn_output.transpose(1, 2)
            .contiguous()
            .view(bsz, seqlen, self.num_heads * self.head_dim)
        )
        return self.o_proj(attn_output)
344
+
345
+
346
class SwiGLUMLP(nn.Module):
    """Position-wise feed-forward network with the SwiGLU activation.

    Computes ``dropout(down_proj(silu(gate_proj(x)) * up_proj(x)))``.
    The two intermediate projections are clamped to the float16-representable
    range so half-precision training cannot overflow.
    """

    def __init__(self, config: ArgonneConfig) -> None:
        super().__init__()
        hidden = config.hidden_size
        inner = config.intermediate_size
        self.gate_proj = nn.Linear(hidden, inner, bias=config.mlp_bias)
        self.up_proj = nn.Linear(hidden, inner, bias=config.mlp_bias)
        self.down_proj = nn.Linear(inner, hidden, bias=config.mlp_bias)
        # Marker consumed by ArgonneModel._init_weights: projections feeding
        # the residual stream get a smaller init std.
        self.down_proj._is_residual = True
        self.dropout = nn.Dropout(config.hidden_dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        fp16_limit = 65504.0
        # Clamp both branches before combining to keep fp16 activations finite.
        gate = self.gate_proj(x).clamp(min=-fp16_limit, max=fp16_limit)
        up = self.up_proj(x).clamp(min=-fp16_limit, max=fp16_limit)
        return self.dropout(self.down_proj(F.silu(gate) * up))
374
+
375
+
376
class Block(nn.Module):
    """Pre-norm transformer block: GQA attention then SwiGLU feed-forward,
    each sub-layer wrapped in RMSNorm and a residual connection."""

    def __init__(self, config: ArgonneConfig, layer_idx: int = 0) -> None:
        super().__init__()
        self.layer_idx = layer_idx
        self.attn = GroupedQueryAttention(config)
        self.input_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = SwiGLUMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Attention sub-layer (pre-norm + residual).
        attn_out = self.attn(self.input_norm(hidden_states), position_embeddings, attention_mask)
        hidden_states = hidden_states + attn_out
        # Feed-forward sub-layer (pre-norm + residual).
        mlp_out = self.mlp(self.post_norm(hidden_states))
        return hidden_states + mlp_out
404
+
405
+
406
class ArgonneModel(PreTrainedModel):
    """Argonne 2 decoder-only causal LM: token embeddings, a stack of
    ``Block`` layers, a final RMSNorm, and a (optionally tied) LM head."""

    config_class = ArgonneConfig
    # Units that must not be split across devices by HF auto device mapping.
    _no_split_modules = ["Block"]
    # Weights aliased to the input embedding when tie_word_embeddings is set.
    _tied_weights_keys = ["lm_head.weight"]
410
+
411
    def __init__(self, config: ArgonneConfig) -> None:
        """Build embeddings, the Block stack, final norm, RoPE tables, and
        the LM head; optionally ties the head to the input embedding."""
        super().__init__(config)
        self.config = config

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.blocks = nn.ModuleList([Block(config, idx) for idx in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # RoPE operates per attention head, hence head_dim = hidden / heads.
        self.rotary_emb = RotaryEmbedding(
            config.hidden_size // config.num_attention_heads,
            max_position_embeddings=config.max_position_embeddings,
            base=config.rope_theta,
        )
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        if config.tie_word_embeddings:
            self.lm_head.weight = self.embed_tokens.weight

        self.gradient_checkpointing = config.use_gradient_checkpointing
        # Pipeline-parallel bookkeeping populated by distribute_model().
        self.pipeline_partitions: Optional[List[Tuple[int, int, torch.device]]] = None
        self.devices: List[torch.device] = []
        self.output_device: torch.device = self.embed_tokens.weight.device
        self.post_init()
433
+
434
    def get_input_embeddings(self) -> nn.Embedding:
        """Return the token-embedding table (HF ``PreTrainedModel`` accessor)."""
        return self.embed_tokens
436
+
437
+ def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
438
+ self.embed_tokens = new_embeddings
439
+ self.config.vocab_size = new_embeddings.num_embeddings
440
+ if self.config.tie_word_embeddings:
441
+ self.lm_head.weight = self.embed_tokens.weight
442
+
443
    def get_output_embeddings(self) -> nn.Module:
        """Return the LM head (HF accessor used for tying/resizing)."""
        return self.lm_head
445
+
446
    def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
        """Replace the LM head (HF accessor).

        When the new head is a plain ``nn.Linear``, the config's vocab size
        is kept in sync with its output width.
        """
        self.lm_head = new_embeddings
        if isinstance(new_embeddings, nn.Linear):
            self.config.vocab_size = new_embeddings.out_features
450
+
451
    def tie_weights(self) -> None:
        """Share the LM-head weight with the input embedding when configured."""
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.embed_tokens.weight
454
+
455
+ def _init_weights(self, module: nn.Module) -> None:
456
+ if isinstance(module, nn.Linear):
457
+ std = self.config.hidden_size ** -0.5
458
+ if hasattr(module, "_is_residual"):
459
+ std = (2 * self.config.num_hidden_layers) ** -0.5
460
+ nn.init.normal_(module.weight, mean=0.0, std=std)
461
+ if module.bias is not None:
462
+ nn.init.zeros_(module.bias)
463
+ elif isinstance(module, nn.Embedding):
464
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.hidden_size ** -0.5)
465
+
466
    def set_gradient_checkpointing(self, enabled: bool = True) -> None:
        """Toggle activation (gradient) checkpointing for the forward pass."""
        self.gradient_checkpointing = enabled
468
+
469
    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None) -> None:
        """HF-compatible enable hook; the kwargs argument is accepted but unused."""
        self.set_gradient_checkpointing(True)
471
+
472
    def gradient_checkpointing_disable(self) -> None:
        """HF-compatible disable hook for gradient checkpointing."""
        self.set_gradient_checkpointing(False)
474
+
475
    def distribute_model(self, device_ids: Optional[List[str]] = None) -> None:
        """Manually shard the model across GPUs for pipeline parallelism.

        Splits the transformer blocks across ``device_ids`` (default: all
        visible CUDA devices) so each stage holds roughly equal parameter
        bytes, accounting for the embedding/rotary tables on the first stage
        and the final norm + LM head on the last. Records the layout in
        ``self.pipeline_partitions`` / ``self.devices`` / ``self.output_device``
        and moves the submodules to their stages. When the head was tied and
        more than one device is used, the head is untied into its own copy on
        the output device (tying cannot span devices).

        Raises:
            ValueError: no CUDA devices, empty ``device_ids``, or no blocks.
        """
        if device_ids is None:
            num_gpus = torch.cuda.device_count()
            if num_gpus < 1:
                raise ValueError("No CUDA devices available for distribution.")
            device_ids = [f"cuda:{i}" for i in range(num_gpus)]

        if not device_ids:
            raise ValueError("device_ids must contain at least one device identifier.")

        self.devices = [torch.device(d) for d in device_ids]
        num_blocks = len(self.blocks)

        if num_blocks == 0:
            raise ValueError("The model has no transformer blocks to distribute.")

        # Parameter byte size of every block, plus a prefix-sum for fast
        # range queries during the balancing search below.
        block_param_bytes: List[int] = []
        for block in self.blocks:
            size_bytes = 0
            for param in block.parameters():
                size_bytes += param.numel() * param.element_size()
            block_param_bytes.append(size_bytes)

        block_cumsum: List[int] = [0]
        for size in block_param_bytes:
            block_cumsum.append(block_cumsum[-1] + size)

        # Non-block payloads: embeddings/rotary live on the first stage,
        # final norm + LM head on the last stage.
        embed_bytes = sum(p.numel() * p.element_size() for p in self.embed_tokens.parameters())
        rotary_bytes = sum(p.numel() * p.element_size() for p in self.rotary_emb.parameters())
        norm_bytes = sum(p.numel() * p.element_size() for p in self.norm.parameters())
        head_dtype_size = self.embed_tokens.weight.element_size()
        head_bytes = self.config.hidden_size * self.config.vocab_size * head_dtype_size
        # A tied head on a single device shares storage with the embedding.
        if self.config.tie_word_embeddings and len(self.devices) == 1:
            head_bytes = 0

        total_bytes = (
            block_cumsum[-1]
            + norm_bytes
            + head_bytes
            + embed_bytes
            + rotary_bytes
        )
        per_device_target = total_bytes / len(self.devices)

        # Greedy pass: choose a cut point per device so each stage's block
        # bytes (plus first-stage overhead) approach the per-device target,
        # while reserving at least one block for every remaining device.
        per_device_counts: List[int] = [0] * len(self.devices)
        prev_cut = 0
        for idx, _ in enumerate(self.devices):
            remaining_devices = len(self.devices) - idx
            remaining_blocks = num_blocks - prev_cut
            if remaining_blocks <= 0:
                per_device_counts[idx] = 0
                continue
            if remaining_devices == 1:
                cut = num_blocks
            else:
                reserve = max(0, min(remaining_devices - 1, remaining_blocks - 1))
                max_cut = prev_cut + (remaining_blocks - reserve)
                lo = prev_cut + 1
                device_overhead = 0
                if idx == 0:
                    device_overhead = embed_bytes + rotary_bytes
                available_block_bytes = per_device_target - device_overhead
                if available_block_bytes <= 0:
                    cut = lo
                else:
                    target_total = block_cumsum[prev_cut] + available_block_bytes
                    # Largest cut whose cumulative bytes stay under target.
                    cut = bisect_right(block_cumsum, target_total, lo=lo, hi=max_cut + 1) - 1
                    if cut < lo:
                        cut = lo
            per_device_counts[idx] = cut - prev_cut
            prev_cut = cut

        def compute_device_block_bytes() -> List[int]:
            # Per-stage block bytes, charging embedding/rotary overhead to
            # the first stage that actually holds blocks.
            device_block_bytes: List[int] = []
            cursor = 0
            first_partition_idx = next(
                (i for i, count in enumerate(per_device_counts) if count > 0),
                0,
            )
            for idx, block_count in enumerate(per_device_counts):
                if block_count <= 0:
                    device_block_bytes.append(0)
                    continue
                next_cursor = min(cursor + block_count, num_blocks)
                block_bytes = block_cumsum[next_cursor] - block_cumsum[cursor]
                if idx == first_partition_idx:
                    block_bytes += embed_bytes + rotary_bytes
                device_block_bytes.append(block_bytes)
                cursor = next_cursor
            if len(device_block_bytes) < len(self.devices):
                device_block_bytes.extend(
                    [0] * (len(self.devices) - len(device_block_bytes))
                )
            return device_block_bytes

        output_payload = norm_bytes + head_bytes

        # Rebalancing pass: the last stage also carries the norm + head, so
        # shift blocks to the previous stage while the last stage (with head)
        # remains heavier than the heaviest other stage.
        device_block_bytes = compute_device_block_bytes()
        positive_indices = [i for i, count in enumerate(per_device_counts) if count > 0]
        if positive_indices:
            last_idx = positive_indices[-1]
            while True:
                if per_device_counts[last_idx] <= 1:
                    break
                other_indices = positive_indices[:-1]
                if not other_indices:
                    break
                other_loads = [device_block_bytes[i] for i in other_indices]
                max_other = max(other_loads) if other_loads else 0
                if max_other == 0:
                    break
                last_load_with_head = device_block_bytes[last_idx] + output_payload
                if last_load_with_head <= max_other:
                    break
                prev_idx = other_indices[-1]
                if per_device_counts[prev_idx] <= 0:
                    break
                per_device_counts[last_idx] -= 1
                per_device_counts[prev_idx] += 1
                device_block_bytes = compute_device_block_bytes()
                positive_indices = [
                    i for i, count in enumerate(per_device_counts) if count > 0
                ]
                last_idx = positive_indices[-1]

        device_block_bytes = compute_device_block_bytes()
        positive_indices = [i for i, count in enumerate(per_device_counts) if count > 0]
        last_active_idx = positive_indices[-1] if positive_indices else 0

        # Materialize the plan: move each stage's blocks to its device.
        partitions: List[Tuple[int, int, torch.device]] = []
        start_idx = 0
        for device, block_count in zip(self.devices, per_device_counts):
            if block_count <= 0 or start_idx >= num_blocks:
                continue
            end_idx = min(start_idx + block_count, num_blocks)
            for block in self.blocks[start_idx:end_idx]:
                block.to(device)
            partitions.append((start_idx, end_idx, device))
            start_idx = end_idx

        # Defensive fallback: if the search produced no partitions, put
        # everything on the first device.
        if not partitions:
            partitions.append((0, num_blocks, self.devices[0]))
            if per_device_counts:
                per_device_counts[0] = num_blocks
            if not device_block_bytes:
                device_block_bytes.append(block_cumsum[num_blocks])
        if not device_block_bytes:
            device_block_bytes = [block_cumsum[num_blocks]]

        self.pipeline_partitions = partitions
        self.output_device = partitions[-1][2]
        output_device_idx = last_active_idx

        first_device = partitions[0][2]
        self.embed_tokens = self.embed_tokens.to(first_device)
        self.rotary_emb = self.rotary_emb.to(first_device)
        self.norm = self.norm.to(self.output_device)

        # A tied head cannot be shared across devices — replace it with an
        # untied copy of the embedding weight on the output device.
        if self.config.tie_word_embeddings and len(self.devices) > 1:
            untied_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
            untied_head.to(self.output_device)
            with torch.no_grad():
                untied_head.weight.copy_(self.embed_tokens.weight.to(self.output_device))
            self.lm_head = untied_head
            self.config.tie_word_embeddings = False
        else:
            self.lm_head = self.lm_head.to(self.output_device)

        # Human-readable summary of the final placement.
        print(f"Model distributed across {len(self.devices)} devices.")
        running = 0
        for idx, (block_count, device) in enumerate(zip(per_device_counts, self.devices)):
            if block_count <= 0:
                print(f"  Stage {idx}: no transformer blocks on {device}")
                continue
            start = running
            end = start + block_count
            running = end
            print(f"  Stage {idx}: layers {start}-{end - 1} on {device}")
            estimated_gb = device_block_bytes[idx] / (1024 ** 3)
            print(f"    ≈{estimated_gb:.2f} GB of parameters")
        print(
            "  Final RMSNorm and LM head on "
            f"{self.output_device} (stage {output_device_idx})"
        )
        output_gb = (device_block_bytes[output_device_idx] + norm_bytes + head_bytes) / (
            1024 ** 3
        )
        print(f"  Estimated post-head load: ≈{output_gb:.2f} GB")
663
+
664
+ def _prepare_attention_mask(
665
+ self,
666
+ attention_mask: Optional[torch.Tensor],
667
+ batch_size: int,
668
+ seq_length: int,
669
+ device: torch.device,
670
+ dtype: torch.dtype,
671
+ ) -> Optional[torch.Tensor]:
672
+ """Prepare 4D attention mask from 2D mask (batch_size, seq_length).
673
+
674
+ Returns a 4D mask suitable for scaled_dot_product_attention.
675
+ The mask should be additive (0 for attend, -inf for mask out).
676
+ """
677
+ if attention_mask is None:
678
+ return None
679
+
680
+ # Convert 2D mask to 4D: (batch_size, seq_length) -> (batch_size, 1, seq_length, seq_length)
681
+ # Create causal mask
682
+ causal_mask = torch.triu(
683
+ torch.ones(seq_length, seq_length, dtype=torch.bool, device=device),
684
+ diagonal=1,
685
+ )
686
+
687
+ # Expand attention_mask from (batch, seq) to (batch, 1, 1, seq)
688
+ expanded_mask = attention_mask[:, None, None, :].expand(batch_size, 1, seq_length, seq_length)
689
+
690
+ # Combine: positions that are either causally masked OR padding should be masked
691
+ # attention_mask is 1 for attend, 0 for mask -> invert for additive mask
692
+ # Use a large negative value instead of -inf to avoid numerical issues in bfloat16
693
+ # -65504 is approximately the most negative value representable in float16
694
+ # Using a more conservative value for numerical stability
695
+ min_dtype = torch.finfo(dtype).min if dtype.is_floating_point else -1e9
696
+ mask_value = max(min_dtype, -65504.0) # Clamp to avoid true -inf
697
+
698
+ combined_mask = torch.where(
699
+ causal_mask | (expanded_mask == 0),
700
+ torch.tensor(mask_value, dtype=dtype, device=device),
701
+ torch.tensor(0.0, dtype=dtype, device=device),
702
+ )
703
+
704
+ return combined_mask
705
+
706
+ def forward(
707
+ self,
708
+ input_ids: torch.LongTensor,
709
+ attention_mask: Optional[torch.Tensor] = None,
710
+ labels: Optional[torch.LongTensor] = None,
711
+ position_ids: Optional[torch.LongTensor] = None,
712
+ **kwargs, # Accept extra args from newer transformers (e.g., num_items_in_batch)
713
+ ) -> CausalLMOutput:
714
+ batch_size, seq_length = input_ids.shape
715
+
716
+ if self.pipeline_partitions:
717
+ first_device = self.pipeline_partitions[0][2]
718
+ hidden_states = self.embed_tokens(input_ids.to(first_device))
719
+
720
+ # Prepare 4D attention mask
721
+ if attention_mask is not None:
722
+ attention_mask = self._prepare_attention_mask(
723
+ attention_mask.to(first_device),
724
+ batch_size,
725
+ seq_length,
726
+ first_device,
727
+ hidden_states.dtype,
728
+ )
729
+
730
+ cos, sin = self.rotary_emb(hidden_states, seq_length)
731
+
732
+ for start, end, device in self.pipeline_partitions:
733
+ if hidden_states.device != device:
734
+ hidden_states = hidden_states.to(device)
735
+ rotary = (cos.to(device), sin.to(device))
736
+ attn_mask = attention_mask.to(device) if attention_mask is not None else None
737
+
738
+ for layer in self.blocks[start:end]:
739
+ if self.gradient_checkpointing and self.training:
740
+ hidden_states = torch.utils.checkpoint.checkpoint(
741
+ layer,
742
+ hidden_states,
743
+ rotary,
744
+ attn_mask,
745
+ use_reentrant=False,
746
+ )
747
+ else:
748
+ hidden_states = layer(hidden_states, rotary, attn_mask)
749
+
750
+ hidden_states = hidden_states.to(self.output_device)
751
+ else:
752
+ device = self.embed_tokens.weight.device
753
+ hidden_states = self.embed_tokens(input_ids.to(device))
754
+
755
+ # Prepare 4D attention mask
756
+ if attention_mask is not None:
757
+ attention_mask = self._prepare_attention_mask(
758
+ attention_mask.to(device),
759
+ batch_size,
760
+ seq_length,
761
+ device,
762
+ hidden_states.dtype,
763
+ )
764
+
765
+ cos, sin = self.rotary_emb(hidden_states, seq_length)
766
+ rotary = (cos, sin)
767
+
768
+ for layer in self.blocks:
769
+ if self.gradient_checkpointing and self.training:
770
+ hidden_states = torch.utils.checkpoint.checkpoint(
771
+ layer,
772
+ hidden_states,
773
+ rotary,
774
+ attention_mask,
775
+ use_reentrant=False,
776
+ )
777
+ else:
778
+ hidden_states = layer(hidden_states, rotary, attention_mask)
779
+
780
+ hidden_states = self.norm(hidden_states)
781
+ logits = self.lm_head(hidden_states)
782
+
783
+ # Check for NaN in logits and handle gracefully
784
+ if torch.isnan(logits).any():
785
+ # Replace NaN with zeros to prevent cascading failures
786
+ logits = torch.nan_to_num(logits, nan=0.0, posinf=65504.0, neginf=-65504.0)
787
+
788
+ loss = None
789
+ if labels is not None:
790
+ shift_logits = logits[..., :-1, :].contiguous()
791
+ shift_labels = labels[..., 1:].contiguous()
792
+ if shift_labels.device != shift_logits.device:
793
+ shift_labels = shift_labels.to(shift_logits.device)
794
+ loss = F.cross_entropy(
795
+ shift_logits.view(-1, shift_logits.size(-1)),
796
+ shift_labels.view(-1),
797
+ ignore_index=-100,
798
+ )
799
+ # Handle NaN loss
800
+ if torch.isnan(loss):
801
+ loss = torch.tensor(0.0, device=loss.device, dtype=loss.dtype, requires_grad=True)
802
+
803
+ return CausalLMOutput(logits=logits, loss=loss)
804
+
805
+ @torch.no_grad()
806
+ def generate(
807
+ self,
808
+ input_ids: torch.Tensor,
809
+ max_length: int = 1024,
810
+ temperature: float = 1.0,
811
+ top_k: Optional[int] = None,
812
+ top_p: Optional[float] = None,
813
+ do_sample: bool = True,
814
+ ) -> torch.Tensor:
815
+ self.eval()
816
+ device = self.pipeline_partitions[0][2] if self.pipeline_partitions else self.embed_tokens.weight.device
817
+ input_ids = input_ids.to(device)
818
+ while input_ids.shape[1] < max_length:
819
+ chunk = input_ids[:, -self.config.max_position_embeddings :]
820
+ outputs = self.forward(chunk)
821
+ logits = outputs.logits[:, -1, :] / temperature
822
+
823
+ if do_sample:
824
+ if top_k is not None:
825
+ top_values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
826
+ logits = logits.masked_fill(logits < top_values[:, [-1]], float("-inf"))
827
+ if top_p is not None:
828
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
829
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
830
+ sorted_indices_to_remove = cumulative_probs > top_p
831
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
832
+ sorted_indices_to_remove[..., 0] = 0
833
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
834
+ logits = logits.masked_fill(indices_to_remove, float("-inf"))
835
+ probs = F.softmax(logits, dim=-1)
836
+ next_token = torch.multinomial(probs, num_samples=1)
837
+ else:
838
+ next_token = torch.argmax(logits, dim=-1, keepdim=True)
839
+
840
+ input_ids = torch.cat([input_ids, next_token.to(input_ids.device)], dim=-1)
841
+ if input_ids.shape[1] >= max_length:
842
+ break
843
+ return input_ids.to(device)
844
+
845
+
846
+ AutoConfig.register("argonne2", ArgonneConfig)
847
+ AutoModel.register(ArgonneConfig, ArgonneModel)
848
+ AutoModelForCausalLM.register(ArgonneConfig, ArgonneModel)
849
+
850
+ # Backwards compatibility exports
851
+ CausalSelfAttention = GroupedQueryAttention
852
+ MLP = SwiGLUMLP
model.safetensors.index.json ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_size": 22147464000
4
+ },
5
+ "weight_map": {
6
+ "embed_tokens.weight": "model-00001-of-00005.safetensors",
7
+ "lm_head.weight": "model-00001-of-00005.safetensors",
8
+ "blocks.0.mlp.gate_proj.weight": "model-00001-of-00005.safetensors",
9
+ "blocks.0.mlp.up_proj.weight": "model-00001-of-00005.safetensors",
10
+ "blocks.0.mlp.down_proj.weight": "model-00002-of-00005.safetensors",
11
+ "blocks.1.mlp.gate_proj.weight": "model-00002-of-00005.safetensors",
12
+ "blocks.1.mlp.up_proj.weight": "model-00002-of-00005.safetensors",
13
+ "blocks.1.mlp.down_proj.weight": "model-00002-of-00005.safetensors",
14
+ "blocks.2.mlp.gate_proj.weight": "model-00002-of-00005.safetensors",
15
+ "blocks.2.mlp.up_proj.weight": "model-00002-of-00005.safetensors",
16
+ "blocks.2.mlp.down_proj.weight": "model-00002-of-00005.safetensors",
17
+ "blocks.3.mlp.gate_proj.weight": "model-00002-of-00005.safetensors",
18
+ "blocks.3.mlp.up_proj.weight": "model-00002-of-00005.safetensors",
19
+ "blocks.3.mlp.down_proj.weight": "model-00002-of-00005.safetensors",
20
+ "blocks.4.mlp.gate_proj.weight": "model-00002-of-00005.safetensors",
21
+ "blocks.4.mlp.up_proj.weight": "model-00002-of-00005.safetensors",
22
+ "blocks.4.mlp.down_proj.weight": "model-00002-of-00005.safetensors",
23
+ "blocks.5.mlp.gate_proj.weight": "model-00002-of-00005.safetensors",
24
+ "blocks.5.mlp.up_proj.weight": "model-00002-of-00005.safetensors",
25
+ "blocks.5.mlp.down_proj.weight": "model-00002-of-00005.safetensors",
26
+ "blocks.6.mlp.gate_proj.weight": "model-00002-of-00005.safetensors",
27
+ "blocks.6.mlp.up_proj.weight": "model-00002-of-00005.safetensors",
28
+ "blocks.6.mlp.down_proj.weight": "model-00002-of-00005.safetensors",
29
+ "blocks.7.mlp.gate_proj.weight": "model-00002-of-00005.safetensors",
30
+ "blocks.7.mlp.up_proj.weight": "model-00002-of-00005.safetensors",
31
+ "blocks.7.mlp.down_proj.weight": "model-00002-of-00005.safetensors",
32
+ "blocks.8.mlp.gate_proj.weight": "model-00002-of-00005.safetensors",
33
+ "blocks.8.mlp.up_proj.weight": "model-00002-of-00005.safetensors",
34
+ "blocks.8.mlp.down_proj.weight": "model-00002-of-00005.safetensors",
35
+ "blocks.9.mlp.gate_proj.weight": "model-00002-of-00005.safetensors",
36
+ "blocks.9.mlp.up_proj.weight": "model-00002-of-00005.safetensors",
37
+ "blocks.9.mlp.down_proj.weight": "model-00002-of-00005.safetensors",
38
+ "blocks.10.mlp.gate_proj.weight": "model-00002-of-00005.safetensors",
39
+ "blocks.10.mlp.up_proj.weight": "model-00003-of-00005.safetensors",
40
+ "blocks.10.mlp.down_proj.weight": "model-00003-of-00005.safetensors",
41
+ "blocks.11.mlp.gate_proj.weight": "model-00003-of-00005.safetensors",
42
+ "blocks.11.mlp.up_proj.weight": "model-00003-of-00005.safetensors",
43
+ "blocks.11.mlp.down_proj.weight": "model-00003-of-00005.safetensors",
44
+ "blocks.12.mlp.gate_proj.weight": "model-00003-of-00005.safetensors",
45
+ "blocks.12.mlp.up_proj.weight": "model-00003-of-00005.safetensors",
46
+ "blocks.12.mlp.down_proj.weight": "model-00003-of-00005.safetensors",
47
+ "blocks.13.mlp.gate_proj.weight": "model-00003-of-00005.safetensors",
48
+ "blocks.13.mlp.up_proj.weight": "model-00003-of-00005.safetensors",
49
+ "blocks.13.mlp.down_proj.weight": "model-00003-of-00005.safetensors",
50
+ "blocks.14.mlp.gate_proj.weight": "model-00003-of-00005.safetensors",
51
+ "blocks.14.mlp.up_proj.weight": "model-00003-of-00005.safetensors",
52
+ "blocks.14.mlp.down_proj.weight": "model-00003-of-00005.safetensors",
53
+ "blocks.15.mlp.gate_proj.weight": "model-00003-of-00005.safetensors",
54
+ "blocks.15.mlp.up_proj.weight": "model-00003-of-00005.safetensors",
55
+ "blocks.15.mlp.down_proj.weight": "model-00003-of-00005.safetensors",
56
+ "blocks.16.mlp.gate_proj.weight": "model-00003-of-00005.safetensors",
57
+ "blocks.16.mlp.up_proj.weight": "model-00003-of-00005.safetensors",
58
+ "blocks.16.mlp.down_proj.weight": "model-00003-of-00005.safetensors",
59
+ "blocks.17.mlp.gate_proj.weight": "model-00003-of-00005.safetensors",
60
+ "blocks.17.mlp.up_proj.weight": "model-00003-of-00005.safetensors",
61
+ "blocks.17.mlp.down_proj.weight": "model-00003-of-00005.safetensors",
62
+ "blocks.18.mlp.gate_proj.weight": "model-00003-of-00005.safetensors",
63
+ "blocks.18.mlp.up_proj.weight": "model-00003-of-00005.safetensors",
64
+ "blocks.18.mlp.down_proj.weight": "model-00003-of-00005.safetensors",
65
+ "blocks.19.mlp.gate_proj.weight": "model-00003-of-00005.safetensors",
66
+ "blocks.19.mlp.up_proj.weight": "model-00003-of-00005.safetensors",
67
+ "blocks.19.mlp.down_proj.weight": "model-00003-of-00005.safetensors",
68
+ "blocks.20.mlp.gate_proj.weight": "model-00004-of-00005.safetensors",
69
+ "blocks.20.mlp.up_proj.weight": "model-00004-of-00005.safetensors",
70
+ "blocks.20.mlp.down_proj.weight": "model-00004-of-00005.safetensors",
71
+ "blocks.21.mlp.gate_proj.weight": "model-00004-of-00005.safetensors",
72
+ "blocks.21.mlp.up_proj.weight": "model-00004-of-00005.safetensors",
73
+ "blocks.21.mlp.down_proj.weight": "model-00004-of-00005.safetensors",
74
+ "blocks.22.mlp.gate_proj.weight": "model-00004-of-00005.safetensors",
75
+ "blocks.22.mlp.up_proj.weight": "model-00004-of-00005.safetensors",
76
+ "blocks.22.mlp.down_proj.weight": "model-00004-of-00005.safetensors",
77
+ "blocks.23.mlp.gate_proj.weight": "model-00004-of-00005.safetensors",
78
+ "blocks.23.mlp.up_proj.weight": "model-00004-of-00005.safetensors",
79
+ "blocks.23.mlp.down_proj.weight": "model-00004-of-00005.safetensors",
80
+ "blocks.0.attn.q_proj.weight": "model-00004-of-00005.safetensors",
81
+ "blocks.0.attn.o_proj.weight": "model-00004-of-00005.safetensors",
82
+ "blocks.1.attn.q_proj.weight": "model-00004-of-00005.safetensors",
83
+ "blocks.1.attn.o_proj.weight": "model-00004-of-00005.safetensors",
84
+ "blocks.2.attn.q_proj.weight": "model-00004-of-00005.safetensors",
85
+ "blocks.2.attn.o_proj.weight": "model-00004-of-00005.safetensors",
86
+ "blocks.3.attn.q_proj.weight": "model-00004-of-00005.safetensors",
87
+ "blocks.3.attn.o_proj.weight": "model-00004-of-00005.safetensors",
88
+ "blocks.4.attn.q_proj.weight": "model-00004-of-00005.safetensors",
89
+ "blocks.4.attn.o_proj.weight": "model-00004-of-00005.safetensors",
90
+ "blocks.5.attn.q_proj.weight": "model-00004-of-00005.safetensors",
91
+ "blocks.5.attn.o_proj.weight": "model-00004-of-00005.safetensors",
92
+ "blocks.6.attn.q_proj.weight": "model-00004-of-00005.safetensors",
93
+ "blocks.6.attn.o_proj.weight": "model-00004-of-00005.safetensors",
94
+ "blocks.7.attn.q_proj.weight": "model-00004-of-00005.safetensors",
95
+ "blocks.7.attn.o_proj.weight": "model-00004-of-00005.safetensors",
96
+ "blocks.8.attn.q_proj.weight": "model-00004-of-00005.safetensors",
97
+ "blocks.8.attn.o_proj.weight": "model-00004-of-00005.safetensors",
98
+ "blocks.9.attn.q_proj.weight": "model-00004-of-00005.safetensors",
99
+ "blocks.9.attn.o_proj.weight": "model-00004-of-00005.safetensors",
100
+ "blocks.10.attn.q_proj.weight": "model-00004-of-00005.safetensors",
101
+ "blocks.10.attn.o_proj.weight": "model-00004-of-00005.safetensors",
102
+ "blocks.11.attn.q_proj.weight": "model-00004-of-00005.safetensors",
103
+ "blocks.11.attn.o_proj.weight": "model-00004-of-00005.safetensors",
104
+ "blocks.12.attn.q_proj.weight": "model-00004-of-00005.safetensors",
105
+ "blocks.12.attn.o_proj.weight": "model-00004-of-00005.safetensors",
106
+ "blocks.13.attn.q_proj.weight": "model-00004-of-00005.safetensors",
107
+ "blocks.13.attn.o_proj.weight": "model-00004-of-00005.safetensors",
108
+ "blocks.14.attn.q_proj.weight": "model-00004-of-00005.safetensors",
109
+ "blocks.14.attn.o_proj.weight": "model-00004-of-00005.safetensors",
110
+ "blocks.15.attn.q_proj.weight": "model-00004-of-00005.safetensors",
111
+ "blocks.15.attn.o_proj.weight": "model-00004-of-00005.safetensors",
112
+ "blocks.16.attn.q_proj.weight": "model-00004-of-00005.safetensors",
113
+ "blocks.16.attn.o_proj.weight": "model-00004-of-00005.safetensors",
114
+ "blocks.17.attn.q_proj.weight": "model-00004-of-00005.safetensors",
115
+ "blocks.17.attn.o_proj.weight": "model-00004-of-00005.safetensors",
116
+ "blocks.18.attn.q_proj.weight": "model-00004-of-00005.safetensors",
117
+ "blocks.18.attn.o_proj.weight": "model-00004-of-00005.safetensors",
118
+ "blocks.19.attn.q_proj.weight": "model-00004-of-00005.safetensors",
119
+ "blocks.19.attn.o_proj.weight": "model-00004-of-00005.safetensors",
120
+ "blocks.20.attn.q_proj.weight": "model-00004-of-00005.safetensors",
121
+ "blocks.20.attn.o_proj.weight": "model-00004-of-00005.safetensors",
122
+ "blocks.21.attn.q_proj.weight": "model-00004-of-00005.safetensors",
123
+ "blocks.21.attn.o_proj.weight": "model-00004-of-00005.safetensors",
124
+ "blocks.22.attn.q_proj.weight": "model-00004-of-00005.safetensors",
125
+ "blocks.22.attn.o_proj.weight": "model-00004-of-00005.safetensors",
126
+ "blocks.23.attn.q_proj.weight": "model-00004-of-00005.safetensors",
127
+ "blocks.23.attn.o_proj.weight": "model-00004-of-00005.safetensors",
128
+ "blocks.0.attn.k_proj.weight": "model-00005-of-00005.safetensors",
129
+ "blocks.0.attn.v_proj.weight": "model-00005-of-00005.safetensors",
130
+ "blocks.1.attn.k_proj.weight": "model-00005-of-00005.safetensors",
131
+ "blocks.1.attn.v_proj.weight": "model-00005-of-00005.safetensors",
132
+ "blocks.2.attn.k_proj.weight": "model-00005-of-00005.safetensors",
133
+ "blocks.2.attn.v_proj.weight": "model-00005-of-00005.safetensors",
134
+ "blocks.3.attn.k_proj.weight": "model-00005-of-00005.safetensors",
135
+ "blocks.3.attn.v_proj.weight": "model-00005-of-00005.safetensors",
136
+ "blocks.4.attn.k_proj.weight": "model-00005-of-00005.safetensors",
137
+ "blocks.4.attn.v_proj.weight": "model-00005-of-00005.safetensors",
138
+ "blocks.5.attn.k_proj.weight": "model-00005-of-00005.safetensors",
139
+ "blocks.5.attn.v_proj.weight": "model-00005-of-00005.safetensors",
140
+ "blocks.6.attn.k_proj.weight": "model-00005-of-00005.safetensors",
141
+ "blocks.6.attn.v_proj.weight": "model-00005-of-00005.safetensors",
142
+ "blocks.7.attn.k_proj.weight": "model-00005-of-00005.safetensors",
143
+ "blocks.7.attn.v_proj.weight": "model-00005-of-00005.safetensors",
144
+ "blocks.8.attn.k_proj.weight": "model-00005-of-00005.safetensors",
145
+ "blocks.8.attn.v_proj.weight": "model-00005-of-00005.safetensors",
146
+ "blocks.9.attn.k_proj.weight": "model-00005-of-00005.safetensors",
147
+ "blocks.9.attn.v_proj.weight": "model-00005-of-00005.safetensors",
148
+ "blocks.10.attn.k_proj.weight": "model-00005-of-00005.safetensors",
149
+ "blocks.10.attn.v_proj.weight": "model-00005-of-00005.safetensors",
150
+ "blocks.11.attn.k_proj.weight": "model-00005-of-00005.safetensors",
151
+ "blocks.11.attn.v_proj.weight": "model-00005-of-00005.safetensors",
152
+ "blocks.12.attn.k_proj.weight": "model-00005-of-00005.safetensors",
153
+ "blocks.12.attn.v_proj.weight": "model-00005-of-00005.safetensors",
154
+ "blocks.13.attn.k_proj.weight": "model-00005-of-00005.safetensors",
155
+ "blocks.13.attn.v_proj.weight": "model-00005-of-00005.safetensors",
156
+ "blocks.14.attn.k_proj.weight": "model-00005-of-00005.safetensors",
157
+ "blocks.14.attn.v_proj.weight": "model-00005-of-00005.safetensors",
158
+ "blocks.15.attn.k_proj.weight": "model-00005-of-00005.safetensors",
159
+ "blocks.15.attn.v_proj.weight": "model-00005-of-00005.safetensors",
160
+ "blocks.16.attn.k_proj.weight": "model-00005-of-00005.safetensors",
161
+ "blocks.16.attn.v_proj.weight": "model-00005-of-00005.safetensors",
162
+ "blocks.17.attn.k_proj.weight": "model-00005-of-00005.safetensors",
163
+ "blocks.17.attn.v_proj.weight": "model-00005-of-00005.safetensors",
164
+ "blocks.18.attn.k_proj.weight": "model-00005-of-00005.safetensors",
165
+ "blocks.18.attn.v_proj.weight": "model-00005-of-00005.safetensors",
166
+ "blocks.19.attn.k_proj.weight": "model-00005-of-00005.safetensors",
167
+ "blocks.19.attn.v_proj.weight": "model-00005-of-00005.safetensors",
168
+ "blocks.20.attn.k_proj.weight": "model-00005-of-00005.safetensors",
169
+ "blocks.20.attn.v_proj.weight": "model-00005-of-00005.safetensors",
170
+ "blocks.21.attn.k_proj.weight": "model-00005-of-00005.safetensors",
171
+ "blocks.21.attn.v_proj.weight": "model-00005-of-00005.safetensors",
172
+ "blocks.22.attn.k_proj.weight": "model-00005-of-00005.safetensors",
173
+ "blocks.22.attn.v_proj.weight": "model-00005-of-00005.safetensors",
174
+ "blocks.23.attn.k_proj.weight": "model-00005-of-00005.safetensors",
175
+ "blocks.23.attn.v_proj.weight": "model-00005-of-00005.safetensors",
176
+ "blocks.0.input_norm.weight": "model-00005-of-00005.safetensors",
177
+ "blocks.0.post_norm.weight": "model-00005-of-00005.safetensors",
178
+ "blocks.1.input_norm.weight": "model-00005-of-00005.safetensors",
179
+ "blocks.1.post_norm.weight": "model-00005-of-00005.safetensors",
180
+ "blocks.2.input_norm.weight": "model-00005-of-00005.safetensors",
181
+ "blocks.2.post_norm.weight": "model-00005-of-00005.safetensors",
182
+ "blocks.3.input_norm.weight": "model-00005-of-00005.safetensors",
183
+ "blocks.3.post_norm.weight": "model-00005-of-00005.safetensors",
184
+ "blocks.4.input_norm.weight": "model-00005-of-00005.safetensors",
185
+ "blocks.4.post_norm.weight": "model-00005-of-00005.safetensors",
186
+ "blocks.5.input_norm.weight": "model-00005-of-00005.safetensors",
187
+ "blocks.5.post_norm.weight": "model-00005-of-00005.safetensors",
188
+ "blocks.6.input_norm.weight": "model-00005-of-00005.safetensors",
189
+ "blocks.6.post_norm.weight": "model-00005-of-00005.safetensors",
190
+ "blocks.7.input_norm.weight": "model-00005-of-00005.safetensors",
191
+ "blocks.7.post_norm.weight": "model-00005-of-00005.safetensors",
192
+ "blocks.8.input_norm.weight": "model-00005-of-00005.safetensors",
193
+ "blocks.8.post_norm.weight": "model-00005-of-00005.safetensors",
194
+ "blocks.9.input_norm.weight": "model-00005-of-00005.safetensors",
195
+ "blocks.9.post_norm.weight": "model-00005-of-00005.safetensors",
196
+ "blocks.10.input_norm.weight": "model-00005-of-00005.safetensors",
197
+ "blocks.10.post_norm.weight": "model-00005-of-00005.safetensors",
198
+ "blocks.11.input_norm.weight": "model-00005-of-00005.safetensors",
199
+ "blocks.11.post_norm.weight": "model-00005-of-00005.safetensors",
200
+ "blocks.12.input_norm.weight": "model-00005-of-00005.safetensors",
201
+ "blocks.12.post_norm.weight": "model-00005-of-00005.safetensors",
202
+ "blocks.13.input_norm.weight": "model-00005-of-00005.safetensors",
203
+ "blocks.13.post_norm.weight": "model-00005-of-00005.safetensors",
204
+ "blocks.14.input_norm.weight": "model-00005-of-00005.safetensors",
205
+ "blocks.14.post_norm.weight": "model-00005-of-00005.safetensors",
206
+ "blocks.15.input_norm.weight": "model-00005-of-00005.safetensors",
207
+ "blocks.15.post_norm.weight": "model-00005-of-00005.safetensors",
208
+ "blocks.16.input_norm.weight": "model-00005-of-00005.safetensors",
209
+ "blocks.16.post_norm.weight": "model-00005-of-00005.safetensors",
210
+ "blocks.17.input_norm.weight": "model-00005-of-00005.safetensors",
211
+ "blocks.17.post_norm.weight": "model-00005-of-00005.safetensors",
212
+ "blocks.18.input_norm.weight": "model-00005-of-00005.safetensors",
213
+ "blocks.18.post_norm.weight": "model-00005-of-00005.safetensors",
214
+ "blocks.19.input_norm.weight": "model-00005-of-00005.safetensors",
215
+ "blocks.19.post_norm.weight": "model-00005-of-00005.safetensors",
216
+ "blocks.20.input_norm.weight": "model-00005-of-00005.safetensors",
217
+ "blocks.20.post_norm.weight": "model-00005-of-00005.safetensors",
218
+ "blocks.21.input_norm.weight": "model-00005-of-00005.safetensors",
219
+ "blocks.21.post_norm.weight": "model-00005-of-00005.safetensors",
220
+ "blocks.22.input_norm.weight": "model-00005-of-00005.safetensors",
221
+ "blocks.22.post_norm.weight": "model-00005-of-00005.safetensors",
222
+ "blocks.23.input_norm.weight": "model-00005-of-00005.safetensors",
223
+ "blocks.23.post_norm.weight": "model-00005-of-00005.safetensors",
224
+ "norm.weight": "model-00005-of-00005.safetensors"
225
+ }
226
+ }
special_tokens_map.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>",
5
+ "<|object_ref_start|>",
6
+ "<|object_ref_end|>",
7
+ "<|box_start|>",
8
+ "<|box_end|>",
9
+ "<|quad_start|>",
10
+ "<|quad_end|>",
11
+ "<|vision_start|>",
12
+ "<|vision_end|>",
13
+ "<|vision_pad|>",
14
+ "<|image_pad|>",
15
+ "<|video_pad|>"
16
+ ],
17
+ "eos_token": {
18
+ "content": "<|im_end|>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ },
24
+ "pad_token": {
25
+ "content": "<|endoftext|>",
26
+ "lstrip": false,
27
+ "normalized": false,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ }
31
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|vision_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|vision_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|vision_pad|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|image_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|video_pad|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|fim_pad|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ }
181
+ },
182
+ "additional_special_tokens": [
183
+ "<|im_start|>",
184
+ "<|im_end|>",
185
+ "<|object_ref_start|>",
186
+ "<|object_ref_end|>",
187
+ "<|box_start|>",
188
+ "<|box_end|>",
189
+ "<|quad_start|>",
190
+ "<|quad_end|>",
191
+ "<|vision_start|>",
192
+ "<|vision_end|>",
193
+ "<|vision_pad|>",
194
+ "<|image_pad|>",
195
+ "<|video_pad|>"
196
+ ],
197
+ "bos_token": null,
198
+ "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. 
You are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
199
+ "clean_up_tokenization_spaces": false,
200
+ "eos_token": "<|im_end|>",
201
+ "errors": "replace",
202
+ "model_max_length": 1000000000,
203
+ "pad_token": "<|endoftext|>",
204
+ "split_special_tokens": false,
205
+ "tokenizer_class": "Qwen2Tokenizer",
206
+ "unk_token": null
207
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff