appledora committed on
Commit 4220e0c · verified · 1 Parent(s): 572ef11

Upload modeling_recast_llama.py with huggingface_hub

Files changed (1): modeling_recast_llama.py (+880 -0)
modeling_recast_llama.py ADDED
@@ -0,0 +1,880 @@
# filename: modeling_recast_llama.py
from .configuration_recast_llama import RECAST8b_llama
from transformers import PreTrainedModel
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Union, List
from transformers import AutoConfig
from transformers.utils import logging
from transformers.cache_utils import Cache, StaticCache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation import GenerationMixin
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaRotaryEmbedding,
    LlamaRMSNorm,
    apply_rotary_pos_emb,
    repeat_kv,
)
from transformers.modeling_outputs import BaseModelOutputWithPast
import copy

logger = logging.get_logger(__name__)

class MLPTemplateBank(nn.Module):
    def __init__(self, config, coef_rows, coef_columns):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.coef_shape = (coef_rows, coef_columns)

        assert coef_columns is not None, "coef_columns must not be None"

        # Ensure divisibility for proper reshaping
        assert (
            self.hidden_size * self.intermediate_size
        ) % coef_rows == 0, f"hidden_size * intermediate_size ({self.hidden_size * self.intermediate_size}) must be divisible by coef_rows ({coef_rows})"

        template_size = self.hidden_size * self.intermediate_size // coef_rows

        self.up_templates = nn.Parameter(torch.randn(coef_columns, template_size))
        self.gate_templates = nn.Parameter(torch.randn(coef_columns, template_size))

        # Better initialization
        nn.init.xavier_uniform_(self.up_templates)
        nn.init.xavier_uniform_(self.gate_templates)

    def forward(self, up_coeffs, gate_coeffs):
        # Compute chunked weights
        up_chunks = torch.matmul(up_coeffs, self.up_templates)
        gate_chunks = torch.matmul(gate_coeffs, self.gate_templates)

        # Reshape to final weight matrices
        up_weights = up_chunks.reshape(self.intermediate_size, self.hidden_size)
        gate_weights = gate_chunks.reshape(self.intermediate_size, self.hidden_size)

        return up_weights, gate_weights

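# Illustrative shape sketch for the reconstruction above (an assumption for the
# sake of the example: Llama-3.1-8B sizes hidden_size=4096, intermediate_size=14336,
# and coef_rows=16; the actual config values may differ):
#   template_size        = 4096 * 14336 // 16
#   bank.up_templates    : (coef_columns, template_size)
#   layer coefficients   : (16, coef_columns)
#   coeffs @ templates   -> (16, template_size) -> reshape -> (14336, 4096)
# so each layer in a group stores only its small coefficient matrices while the
# full up/gate projection weights are rebuilt from the shared templates.
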
class SharedLlamaMLP(nn.Module):
    def __init__(self, config, bank):
        super().__init__()
        self.config = config
        self.bank = bank
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.down_proj = nn.Linear(
            config.intermediate_size, config.hidden_size, bias=False
        )

        # Initialize coefficients with proper shapes
        self.up_coefficients = nn.Parameter(torch.randn(bank.coef_shape))
        self.gate_coefficients = nn.Parameter(torch.randn(bank.coef_shape))

        # Initialize with small random values instead of ones, then orthogonalize
        nn.init.orthogonal_(self.up_coefficients)
        nn.init.orthogonal_(self.gate_coefficients)

        if config.mlp_bias:
            self.gate_bias = nn.Parameter(torch.zeros(self.intermediate_size))
            self.up_bias = nn.Parameter(torch.zeros(self.intermediate_size))
        else:
            self.register_parameter("gate_bias", None)
            self.register_parameter("up_bias", None)

        self.act_fn = F.silu

    def forward(self, x):
        # Generate weights using template bank
        up_weights, gate_weights = self.bank(
            self.up_coefficients, self.gate_coefficients  # Fixed order
        )
        # Match dtype
        target_dtype = x.dtype
        up_weights = up_weights.to(target_dtype)
        gate_weights = gate_weights.to(target_dtype)
        # Apply SwiGLU: SiLU(gate * x) * up * x
        hidden_states = self.act_fn(
            F.linear(x, gate_weights, self.gate_bias)
        ) * F.linear(x, up_weights, self.up_bias)
        output = self.down_proj(hidden_states)

        return output

class AttTemplateBank(nn.Module):
    def __init__(self, config, coef_rows, coef_columns):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.num_key_value_heads = getattr(
            config, "num_key_value_heads", config.num_attention_heads
        )
        self.kv_dim = self.num_key_value_heads * self.head_dim
        self.coef_shape = (coef_rows, coef_columns)

        # Ensure divisibility
        assert (
            self.hidden_size * self.hidden_size
        ) % coef_rows == 0, "Q projection size must be divisible by coef_rows"
        assert (
            self.kv_dim * self.hidden_size
        ) % coef_rows == 0, "K/V projection size must be divisible by coef_rows"

        # Create templates for Q, K, V
        self.q_templates = nn.Parameter(
            torch.randn(coef_columns, self.hidden_size * self.hidden_size // coef_rows)
        )
        self.k_templates = nn.Parameter(
            torch.randn(coef_columns, self.kv_dim * self.hidden_size // coef_rows)
        )
        self.v_templates = nn.Parameter(
            torch.randn(coef_columns, self.kv_dim * self.hidden_size // coef_rows)
        )

        # Initialize templates
        nn.init.xavier_uniform_(self.q_templates)
        nn.init.xavier_uniform_(self.k_templates)
        nn.init.xavier_uniform_(self.v_templates)

    def forward(self, q_coeffs, k_coeffs, v_coeffs):
        # Compute chunked weights
        q_chunks = torch.matmul(q_coeffs, self.q_templates)
        k_chunks = torch.matmul(k_coeffs, self.k_templates)
        v_chunks = torch.matmul(v_coeffs, self.v_templates)

        # Reshape to final weight matrices
        q_weights = q_chunks.reshape(self.hidden_size, self.hidden_size)
        k_weights = k_chunks.reshape(self.kv_dim, self.hidden_size)
        v_weights = v_chunks.reshape(self.kv_dim, self.hidden_size)

        return q_weights, k_weights, v_weights

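# Illustrative shapes for the attention bank (again assuming Llama-3.1-8B sizes:
# hidden_size=4096, 32 attention heads, 8 key/value heads, so head_dim=128 and
# kv_dim=1024):
#   q_templates reconstruct a (4096, 4096) q_proj weight,
#   k_templates / v_templates reconstruct (1024, 4096) k_proj / v_proj weights,
# matching the grouped-query layout that repeat_kv expands inside
# SharedLlamaAttention below.
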
class SharedLlamaAttention(nn.Module):
    def __init__(
        self,
        config,
        layer_idx: Optional[int] = None,
        bank: Optional[AttTemplateBank] = None,
    ):
        super().__init__()
        self.config = config
        self.bank = bank
        self.layer_idx = layer_idx
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = getattr(
            config, "num_key_value_heads", config.num_attention_heads
        )
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = getattr(config, "rope_theta", 10000.0)
        self.is_causal = True

        self.o_proj = nn.Linear(
            self.hidden_size,
            self.hidden_size,
            bias=getattr(config, "attention_bias", False),
        )
        self.rotary_emb = LlamaRotaryEmbedding(config=self.config)

        # Initialize coefficients with proper shapes
        self.q_coefficients = nn.Parameter(torch.randn(bank.coef_shape))
        self.k_coefficients = nn.Parameter(torch.randn(bank.coef_shape))
        self.v_coefficients = nn.Parameter(torch.randn(bank.coef_shape))

        # Initialize with small random values
        nn.init.orthogonal_(self.q_coefficients)
        nn.init.orthogonal_(self.k_coefficients)
        nn.init.orthogonal_(self.v_coefficients)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        past_key_value=None,
        cache_position=None,
        position_embeddings=None,
        position_ids=None,
        output_attentions=False,
        use_cache=False,
        **kwargs,
    ):
        bsz, q_len, _ = hidden_states.size()

        # Generate weights using template bank
        q_weights, k_weights, v_weights = self.bank(
            self.q_coefficients, self.k_coefficients, self.v_coefficients
        )
        target_dtype = hidden_states.dtype
        q_weights = q_weights.to(target_dtype)
        k_weights = k_weights.to(target_dtype)
        v_weights = v_weights.to(target_dtype)

        # Apply projections
        query_states = F.linear(hidden_states, q_weights)
        key_states = F.linear(hidden_states, k_weights)
        value_states = F.linear(hidden_states, v_weights)

        # Reshape for multi-head attention
        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)

        # Apply rotary embeddings
        if position_embeddings is None:
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings

        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )

        # Handle past key values
        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        # Repeat key/value for grouped query attention
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # ============ CRITICAL CHANGE: Use SDPA instead of manual attention ============
        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # Make contiguous for SDPA
        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # Determine is_causal for SDPA
        is_causal = True if causal_mask is None and q_len > 1 else False

        # Use PyTorch's SDPA (same as LlamaSdpaAttention)
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )
        # ============================================================================

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value

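# Behavior note for the SDPA branch above: when no explicit mask is supplied and
# the query length is greater than 1, scaled_dot_product_attention is asked to
# build the causal mask itself (is_causal=True); otherwise the 4D causal_mask
# prepared by _update_causal_mask is sliced to the cached key length and passed
# as attn_mask.
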
def fixed_cross_entropy(
    source,
    target,
    num_items_in_batch: int = None,
    ignore_index: int = -100,
    **kwargs,
):
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = nn.functional.cross_entropy(
        source, target, ignore_index=ignore_index, reduction=reduction
    )
    if reduction == "sum":
        loss = loss / num_items_in_batch
    return loss

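# Note on the reduction above: when num_items_in_batch is provided, the summed
# cross-entropy is divided by that count, so passing e.g. the number of
# non-ignored labels reproduces a per-token mean; with num_items_in_batch=None
# the standard mean over valid (non -100) targets is used.
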
class RECAST8b_llamaModel(PreTrainedModel):
    config_class = RECAST8b_llama
    base_model_prefix = "llama"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlamaDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"

    def __init__(self, config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )

        original_config = AutoConfig.from_pretrained(
            "meta-llama/Llama-3.1-8b", trust_remote_code=True
        )
        self.rotary_emb = LlamaRotaryEmbedding(
            config=original_config,
        )

        # Create template banks first
        self.mlp_banks = []
        self.attn_banks = []
        layers_per_group = config.num_hidden_layers // config.num_groups
        # Explicitly calculate coef_width if not provided in config
        if hasattr(config, "coef_width") and config.coef_width is not None:
            coef_width = config.coef_width
        else:
            coef_width = config.coef_height * layers_per_group
            config.coef_width = coef_width
        print(
            f"Model config: num_groups={config.num_groups}, layers_per_group={layers_per_group}, K={config.k}"
        )
        print(f"Coefficient shape: ({config.coef_height}, {config.coef_width})")

        mlp_banks = nn.ModuleList(
            [
                MLPTemplateBank(
                    config=config, coef_rows=config.coef_height, coef_columns=coef_width
                )
                for _ in range(config.num_groups)
            ]
        )

        attn_banks = nn.ModuleList(
            [
                AttTemplateBank(
                    config=config, coef_rows=config.coef_height, coef_columns=coef_width
                )
                for _ in range(config.num_groups)
            ]
        )
        self.mlp_banks = mlp_banks
        self.attn_banks = attn_banks
        # Create layers using LlamaDecoderLayer but replace MLPs
        self.layers = nn.ModuleList()
        for layer_idx in range(config.num_hidden_layers):
            # Create standard LlamaDecoderLayer
            decoder_layer = LlamaDecoderLayer(config, layer_idx)
            # Replace its MLP with our SharedLlamaMLP
            group_idx = layer_idx // layers_per_group
            decoder_layer.mlp = SharedLlamaMLP(
                config=config,
                bank=self.mlp_banks[group_idx],
            )
            decoder_layer.self_attn = SharedLlamaAttention(
                config=config,
                layer_idx=layer_idx,
                bank=self.attn_banks[group_idx],
            )

            self.layers.append(decoder_layer)

        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

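    # Illustrative grouping for the constructor above (assuming, for example,
    # num_hidden_layers=32 and num_groups=4): layers_per_group = 8, so layers
    # 0-7 share mlp_banks[0] / attn_banks[0], layers 8-15 share bank index 1,
    # and so on, via group_idx = layer_idx // layers_per_group.
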
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds"
            )

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        # Set up cache position if not provided
        if cache_position is None:
            past_seen_tokens = (
                0
                if past_key_values is None
                else (
                    past_key_values.get_seq_length()
                    if isinstance(past_key_values, Cache)
                    else past_key_values[0][0].size(-2) if past_key_values else 0
                )
            )
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )
        # Create position embeddings to be shared across the decoder layers
        # Set up position IDs if not provided
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)
        # Get updated causal mask
        causal_mask = self._update_causal_mask(
            attention_mask,
            inputs_embeds,
            cache_position,
            past_key_values,
            output_attentions,
        )
        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # Initialize outputs
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        # Process through layers
        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    position_embeddings=position_embeddings,
                    **flash_attn_kwargs,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        # Final layer norm
        hidden_states = self.norm(hidden_states)

        # Add last hidden state
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None
            )

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        if isinstance(
            pretrained_model_name_or_path, str
        ) and pretrained_model_name_or_path.endswith(".pt"):
            print("Loading from local checkpoint")
            # Load from local checkpoint
            config = kwargs.get("config", None)
            if config is None:
                config = AutoConfig.from_pretrained(
                    pretrained_model_name_or_path, trust_remote_code=True
                )

            model = cls(config)
            checkpoint = torch.load(pretrained_model_name_or_path, map_location="cpu")
            state_dict = checkpoint["model_state_dict"]
            logger.info(
                f"Loaded checkpoint from epoch {checkpoint.get('epoch')} with loss {checkpoint.get('loss')}"
            )

            missing_keys, unexpected_keys = model.load_state_dict(
                state_dict, strict=False
            )

            if len(missing_keys) > 0:
                logger.warning(f"Missing keys: {missing_keys}")
            if len(unexpected_keys) > 0:
                logger.warning(f"Unexpected keys: {unexpected_keys}")

            return model
        else:
            print("Loading from hub")
            # Load from hub using parent's from_pretrained
            return super().from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = (
            past_key_values.get_seq_length() if past_key_values is not None else 0
        )
        using_static_cache = isinstance(past_key_values, StaticCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if (
            self.config._attn_implementation == "sdpa"
            and not using_static_cache
            and not output_attentions
        ):
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(
                causal_mask, min_dtype
            )

        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length),
                fill_value=min_dtype,
                dtype=dtype,
                device=device,
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(
                target_length, device=device
            ) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = (
                    causal_mask.clone()
                )  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                padding_mask = (
                    causal_mask[:, :, :, :mask_length]
                    + attention_mask[:, None, None, :]
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[
                    :, :, :, :mask_length
                ].masked_fill(padding_mask, min_dtype)

        return causal_mask

class RECAST8b_LlamaForCausalLM(PreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    config_class = RECAST8b_llama
    base_model_prefix = "llama"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlamaDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"

    def __init__(self, config):
        super().__init__(config)
        self.model = RECAST8b_llamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def loss_function(
        self,
        logits,
        labels,
        vocab_size: int,
        num_items_in_batch: int = None,
        ignore_index: int = -100,
        **kwargs,
    ):
        # Upcast to float if we need to compute the loss to avoid potential precision issues
        logits = logits.float()
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        shift_logits = shift_logits.view(-1, vocab_size)
        shift_labels = shift_labels.view(-1)
        # Enable model parallelism
        shift_labels = shift_labels.to(shift_logits.device)
        loss = fixed_cross_entropy(
            shift_logits, shift_labels, num_items_in_batch, ignore_index, **kwargs
        )
        return loss

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should be in
                `[0, ..., config.vocab_size]` or -100 (masked tokens).
            num_logits_to_keep (`int`, *optional*):
                Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate all logits.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0]
        # Only compute necessary logits
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

        loss = None
        if labels is not None:
            # Calculate batch size for loss function
            num_items_in_batch = (
                input_ids.size(0) if input_ids is not None else inputs_embeds.size(0)
            )
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                num_items_in_batch=num_items_in_batch,
                **kwargs,
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ):
        if past_key_values:
            input_ids = input_ids[:, -1:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        if isinstance(
            pretrained_model_name_or_path, str
        ) and pretrained_model_name_or_path.endswith(".pt"):
            print("Loading from local checkpoint")
            config = kwargs.get("config", None)
            if config is None:
                config = AutoConfig.from_pretrained(
                    pretrained_model_name_or_path, trust_remote_code=True
                )
            model = torch.load(pretrained_model_name_or_path, map_location="cpu")

            return model
        else:
            print("Loading from hub")
            return super().from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )
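
For reference, a minimal usage sketch. Two assumptions that are not confirmed by this commit: the repository id below is a placeholder, and the repo's config.json registers this class under auto_map so that AutoModelForCausalLM resolves to RECAST8b_LlamaForCausalLM when trust_remote_code=True.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "appledora/RECAST8b-llama"  # placeholder repo id, substitute the real one
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
model.eval()

# Short generation round-trip to sanity-check the shared template banks.
inputs = tokenizer("Template banks let grouped layers share", return_tensors="pt")
with torch.no_grad():
    generated = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(generated[0], skip_special_tokens=True))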