Youzhi Yu committed on
Commit
4395cf9
·
0 Parent(s):

Initial commit of Argonne-1.5 model

Browse files
.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ *.bin filter=lfs diff=lfs merge=lfs -text
config.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "argonne",
3
+ "architectures": [
4
+ "ArgonneModel"
5
+ ],
6
+ "block_size": 2048,
7
+ "dropout": 0.1,
8
+ "n_embd": 1296,
9
+ "n_head": 16,
10
+ "n_layer": 16,
11
+ "use_flash_attn": true,
12
+ "vocab_size": 12000,
13
+ "torch_dtype": "float16",
14
+ "transformers_version": "4.44.0",
15
+ "auto_map": {
16
+ "AutoConfig": "model.ArgonneConfig",
17
+ "AutoModel": "model.ArgonneModel",
18
+ "AutoModelForCausalLM": "model.ArgonneModel"
19
+ }
20
+ }
21
+
generation_config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "transformers_version": "4.44.0"
4
+ }
model.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from transformers import (
6
+ PretrainedConfig,
7
+ PreTrainedModel,
8
+ AutoConfig,
9
+ AutoModel,
10
+ AutoModelForCausalLM
11
+ )
12
+
13
class ArgonneConfig(PretrainedConfig):
    """Configuration for the Argonne GPT-style causal language model.

    Stores the architectural hyperparameters consumed by ``ArgonneModel``:
    vocabulary size, maximum context length, depth/width of the transformer
    stack, dropout rate, and whether the fused attention kernel is used.
    """

    model_type = "argonne"

    def __init__(
        self,
        vocab_size=12000,
        block_size=2048,
        n_layer=24,
        n_head=24,
        n_embd=1296,
        dropout=0.1,
        use_flash_attn=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Architectural dimensions.
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd
        # Regularization and kernel selection.
        self.dropout = dropout
        self.use_flash_attn = use_flash_attn
25
class Block(nn.Module):
    """One pre-norm transformer block: self-attention then MLP, each wrapped
    in a residual connection with LayerNorm applied before the sub-layer."""

    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        """Apply the attention and feed-forward sub-layers with residual skips."""
        h = x + self.attn(self.ln1(x))
        return h + self.mlp(self.ln2(h))
37
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention.

    Uses PyTorch's fused ``scaled_dot_product_attention`` kernel when it is
    available and enabled in the config; otherwise falls back to an explicit
    masked softmax-attention implementation.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0, "Embedding dim must be divisible by n_head"
        self.n_head = config.n_head
        self.head_dim = config.n_embd // config.n_head
        # Separate Q/K/V projections (rather than a single fused projection).
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)
        self.attn_drop = nn.Dropout(config.dropout)
        self.resid_drop = nn.Dropout(config.dropout)
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        self.use_flash_attn = getattr(config, 'use_flash_attn', True)

        # Lower-triangular mask, used only by the manual (non-fused) path.
        causal = torch.tril(torch.ones(config.block_size, config.block_size))
        self.register_buffer(
            "mask",
            causal.view(1, 1, config.block_size, config.block_size)
        )

    def forward(self, x):
        """Attend over the sequence causally; returns a tensor shaped like *x*."""
        batch, seq_len, embd = x.size()

        def heads(linear):
            # Project and reshape to (batch, n_head, seq_len, head_dim).
            return linear(x).view(batch, seq_len, self.n_head, self.head_dim).transpose(1, 2)

        q, k, v = heads(self.query), heads(self.key), heads(self.value)

        if self.use_flash_attn and hasattr(F, 'scaled_dot_product_attention'):
            # Fused kernel; is_causal=True lets PyTorch build the triangular
            # mask internally, so no explicit mask tensor is passed.
            out = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.0,
                is_causal=True,
            )
        else:
            # Manual fallback: scaled scores, causal mask, softmax, dropout.
            scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
            scores = scores.masked_fill(self.mask[:, :, :seq_len, :seq_len] == 0, float('-inf'))
            weights = self.attn_drop(torch.softmax(scores, dim=-1))
            out = weights @ v

        # Merge heads back into the embedding dimension, then project.
        out = out.transpose(1, 2).contiguous().view(batch, seq_len, embd)
        return self.resid_drop(self.proj(out))
85
class MLP(nn.Module):
    """Position-wise feed-forward network: expand to 4x the embedding width,
    apply GELU, project back down; dropout after both the activation and the
    output projection."""

    def __init__(self, config):
        super().__init__()
        hidden = 4 * config.n_embd
        self.fc1 = nn.Linear(config.n_embd, hidden)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden, config.n_embd)
        self.drop = nn.Dropout(config.dropout)

    def forward(self, x):
        """Run the two-layer MLP on the last dimension of *x*."""
        h = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(h))
100
class ArgonneModel(PreTrainedModel):
    """GPT-style causal language model with optional pipeline parallelism.

    A stack of pre-norm transformer ``Block``s over learned token and absolute
    position embeddings, followed by a final LayerNorm and a bias-free LM head.
    The blocks can be split across GPUs either manually (``distribute_model``,
    pipeline style) or via the ``accelerate`` library (``setup_device_map``).

    NOTE(review): ``forward`` returns a plain ``(logits, loss)`` tuple rather
    than a transformers ``ModelOutput``; callers must unpack accordingly.
    """

    config_class = ArgonneConfig

    def __init__(self, config, device_map=None):
        super().__init__(config)
        # Create embeddings on CPU initially; they may be moved later by
        # distribute_model() or setup_device_map().
        self.token_embedding = nn.Embedding(config.vocab_size, config.n_embd)
        # Learned absolute position embedding: one row per position up to block_size.
        self.position_embedding = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
        self.drop = nn.Dropout(config.dropout)

        # Build all blocks
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])

        # Final LayerNorm + output head
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        nn.init.normal_(self.position_embedding, mean=0.0, std=0.02)
        self.post_init()

        # For pipeline parallelism: populated by distribute_model(); forward()
        # checks pipeline_stages to pick the single-device vs. pipelined path.
        self.pipeline_stages = None
        self.devices = []

        # Handle device_map="auto" for inference
        if device_map is not None:
            self.setup_device_map(device_map)

    def setup_device_map(self, device_map):
        """
        Set up the model on devices according to device_map.
        If device_map="auto", use accelerate to automatically assign model parts to devices.

        Args:
            device_map: The string "auto" (delegate placement to `accelerate`)
                or a custom mapping (currently a no-op placeholder).
        """
        if device_map == "auto":
            try:
                from accelerate import dispatch_model
                from accelerate.utils import infer_auto_device_map

                # Get device map automatically
                auto_device_map = infer_auto_device_map(self)
                # Dispatch model across devices
                dispatch_model(self, device_map=auto_device_map)

                print(f"Model automatically distributed across devices with device_map: {auto_device_map}")

            except ImportError:
                # Best-effort fallback: report the missing dependency and keep
                # the model on its current device instead of raising.
                print("The 'accelerate' library is required for device_map='auto'. Please install it with 'pip install accelerate'.")
                print("Continuing with model on CPU or default device.")
        else:
            # Handle custom device map
            # This would be a more complex implementation where the user provides a specific mapping
            # of model components to devices
            pass

    def distribute_model(self, device_ids=None):
        """
        Distribute the model blocks across multiple GPU devices in a pipeline style.
        If 'device_ids' is None, we'll discover all available GPUs.

        Embeddings go to the first device, the final LayerNorm and LM head to
        the last; the transformer blocks are split into contiguous stages of
        roughly equal size, one stage per device.

        Raises:
            ValueError: if no GPUs are available and device_ids is None.
        """
        if device_ids is None:
            num_gpus = torch.cuda.device_count()
            if num_gpus < 1:
                raise ValueError("No GPUs found—can't do pipeline parallel on CPU only.")
            device_ids = [f"cuda:{i}" for i in range(num_gpus)]

        # Store them so the training loop can keep referencing model.devices
        self.devices = [torch.device(d) for d in device_ids]

        self.pipeline_stages = nn.ModuleList()
        num_gpus = len(device_ids)
        blocks_per_gpu = math.ceil(len(self.blocks) / num_gpus)

        start_idx = 0
        for i in range(num_gpus):
            end_idx = min(start_idx + blocks_per_gpu, len(self.blocks))
            stage_blocks = self.blocks[start_idx:end_idx]
            stage = nn.Sequential(*stage_blocks).to(device_ids[i])
            self.pipeline_stages.append(stage)
            start_idx = end_idx
            # Stop early once every block has been assigned (the last device(s)
            # may receive no stage when blocks don't divide evenly).
            if end_idx >= len(self.blocks):
                break

        # Move embeddings to the first device
        first_device = device_ids[0]
        self.token_embedding = self.token_embedding.to(first_device)
        # For nn.Parameter, we need to move the data, not replace the parameter
        self.position_embedding.data = self.position_embedding.data.to(first_device)
        self.drop = self.drop.to(first_device)

        # Move final LayerNorm + head to the last device
        last_device = device_ids[-1]
        self.ln_f = self.ln_f.to(last_device)
        self.head = self.head.to(last_device)

        print(f"Model distributed across {len(device_ids)} devices")
        print(f"First device: {first_device}, Last device: {last_device}")
        print(f"Transformer layers per device: ~{blocks_per_gpu}")

    def _init_weights(self, module):
        # Standard GPT-style init (called by post_init() for every submodule):
        # normal(0, 0.02) for Linear/Embedding weights, zeros for Linear biases.
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def prepare_for_compile(self):
        """
        Prepare model for torch.compile() by ensuring all components
        are compatible with the compiler.
        """
        # Some models may need special handling for compilation
        # For now, we'll just return self since our model structure should be compatible
        return self

    def forward(self, idx, targets=None):
        """
        If self.pipeline_stages is None, we do a normal single-device forward
        (whatever device everything is currently on—CPU or a single GPU).
        Otherwise, we do a pipeline parallel forward.

        Args:
            idx: Token-id tensor of shape (batch, seq) — or (seq,), in which
                case a batch dimension is added.
            targets: Optional token-id tensor of the same shape as idx used to
                compute a cross-entropy loss over all positions.

        Returns:
            A (logits, loss) tuple; loss is None when targets is None. Note
            that when targets is given, logits is returned flattened to
            (batch * seq, vocab_size) by the loss computation.
        """
        # Make the forward method more compiler-friendly
        if idx.dim() == 1:
            # Add batch dimension if missing
            idx = idx.unsqueeze(0)

        # Rest of the forward method remains the same
        if self.pipeline_stages is None:
            # Single-device forward pass
            device = self.token_embedding.weight.device
            idx = idx.to(device)
            b, t = idx.size()
            assert t <= self.config.block_size, "Sequence length exceeds block size"

            token_embeddings = self.token_embedding(idx)
            # Slice the position table to the current sequence length.
            position_embeddings = self.position_embedding[:, :t, :]
            hidden_states = self.drop(token_embeddings + position_embeddings)

            for block in self.blocks:
                hidden_states = block(hidden_states)

            hidden_states = self.ln_f(hidden_states)
            logits = self.head(hidden_states)

            loss = None
            if targets is not None:
                targets = targets.to(device)
                logits = logits.view(-1, logits.size(-1))
                targets = targets.view(-1)
                loss = F.cross_entropy(logits, targets)

            return logits, loss
        else:
            # Pipeline parallel forward: embeddings live on the first device,
            # ln_f/head on the last; activations hop between stage devices.
            first_device = next(self.token_embedding.parameters()).device
            last_device = next(self.ln_f.parameters()).device

            x = idx.to(first_device)
            b, t = x.size()
            assert t <= self.config.block_size, "Sequence length exceeds block size"

            token_embeddings = self.token_embedding(x)
            position_embeddings = self.position_embedding[:, :t, :]
            hidden_states = self.drop(token_embeddings + position_embeddings)

            # Pass through each pipeline stage in sequence
            for stage_idx, stage in enumerate(self.pipeline_stages):
                device_stage = next(stage.parameters()).device
                hidden_states = hidden_states.to(device_stage)
                hidden_states = stage(hidden_states)

            # Explicitly move to last device before final operations
            hidden_states = hidden_states.to(last_device)
            hidden_states = self.ln_f(hidden_states)
            logits = self.head(hidden_states)

            loss = None
            if targets is not None:
                targets = targets.to(last_device)
                logits = logits.view(-1, logits.size(-1))
                targets = targets.view(-1)
                loss = F.cross_entropy(logits, targets)

            return logits, loss

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens, temperature=0.7, top_k=None, top_p=None, sample=True):
        """
        Generate text using the model.

        Args:
            input_ids: Input token IDs to continue from
            max_new_tokens: Number of tokens to generate
            temperature: Temperature for sampling (higher = more random)
            top_k: If set, only sample from the top k most likely tokens
            top_p: If set, sample from the smallest set of tokens whose cumulative probability exceeds p
            sample: If True, sample from the distribution; if False, use greedy decoding

        Returns:
            Tensor containing the input_ids extended with max_new_tokens generated tokens
        """
        self.eval()

        # Determine which device to use - explicitly use first device for consistency
        if self.pipeline_stages is not None and len(self.devices) > 0:
            device = self.devices[0]  # Always use first device for generation
        else:
            device = next(self.parameters()).device

        # Ensure input is on the correct device
        generated = input_ids.to(device)

        for _ in range(max_new_tokens):
            # Truncate if necessary to fit within the model's context window
            # (sliding window: keep only the most recent block_size tokens).
            if generated.shape[1] > self.config.block_size:
                generated = generated[:, -self.config.block_size:]

            # Forward pass (no KV cache: the full prefix is re-encoded each step)
            logits, _ = self.forward(generated)

            # Make sure logits are on the same device
            logits = logits.to(device)

            # Get logits for the last token only
            logits = logits[:, -1, :]

            # Apply temperature
            if temperature != 1.0:
                logits = logits / temperature

            # Greedy decoding (argmax) if sample=False
            if not sample:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                # Sampling logic
                # Apply top-k filtering: mask out everything below the k-th logit
                if top_k is not None:
                    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
                    logits = logits.masked_fill(indices_to_remove, float('-inf'))

                # Apply top-p (nucleus) filtering
                if top_p is not None:
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                    # Remove tokens with cumulative probability above the threshold
                    sorted_indices_to_remove = cumulative_probs > top_p

                    # Shift the indices to the right to keep the first token above the threshold
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0

                    # Scatter the sorted-order mask back into vocabulary order.
                    indices_to_remove = sorted_indices_to_remove.scatter(
                        dim=1, index=sorted_indices, src=sorted_indices_to_remove
                    )
                    logits = logits.masked_fill(indices_to_remove, float('-inf'))

                # Convert to probability distribution and sample
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

            # Ensure next_token is on the same device before concatenation
            next_token = next_token.to(device)

            # Append the generated token to the sequence
            generated = torch.cat((generated, next_token), dim=1)

        return generated
369
# Register the model with Hugging Face's Auto classes so that
# AutoConfig / AutoModel / AutoModelForCausalLM .from_pretrained(...) can
# resolve model_type "argonne" (see auto_map in config.json) to these
# local implementations.
AutoConfig.register("argonne", ArgonneConfig)
AutoModel.register(ArgonneConfig, ArgonneModel)
AutoModelForCausalLM.register(ArgonneConfig, ArgonneModel)
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4fd21fa25e0b165ea52aec7972f05be959c82adb3d48980f73a71de042e28daa
3
+ size 847334947
special_tokens_map.json ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|start_of_text|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|end_of_text|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "mask_token": {
17
+ "content": "<mask>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "pad_token": {
24
+ "content": "<pad>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "unk_token": {
31
+ "content": "<unk>",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ }
37
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "<|start_of_text|>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "<pad>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "<|end_of_text|>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "<unk>",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "4": {
36
+ "content": "<mask>",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "bos_token": "<|start_of_text|>",
45
+ "clean_up_tokenization_spaces": true,
46
+ "eos_token": "<|end_of_text|>",
47
+ "mask_token": "<mask>",
48
+ "model_max_length": 1000000000000000019884624838656,
49
+ "pad_token": "<pad>",
50
+ "tokenizer_class": "PreTrainedTokenizerFast",
51
+ "unk_token": "<unk>",
52
+ "use_fast": true
53
+ }