Sayan01 committed
Commit dbbfe1c · verified · 1 Parent(s): a353f87

Update modeling_mobillama.py

Files changed (1):
  1. modeling_mobillama.py +272 -787
modeling_mobillama.py CHANGED
@@ -1,10 +1,10 @@
  # coding=utf-8
  # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  #
  # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
- # and OPT implementations in this library. It has been modified from its
- # original forms to accommodate minor architectural differences compared
- # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -17,852 +17,337 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- """ PyTorch LLaMA model."""
- import math
- from typing import List, Optional, Tuple, Union

  import torch
- import torch.utils.checkpoint
  from torch import nn
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-
- from transformers.activations import ACT2FN
- from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
- from transformers.modeling_utils import PreTrainedModel
- from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
- from transformers.models.llama.configuration_llama import LlamaConfig
-
- # from .configuration_mobillama import MobiLlamaConfig
-
- from flash_attn import flash_attn_func
-
-
- logger = logging.get_logger(__name__)
-
- _CONFIG_FOR_DOC = "LlamaConfig"
-
-
- # Copied from transformers.models.bart.modeling_bart._make_causal_mask
- def _make_causal_mask(
-     input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
- ):
-     """
-     Make causal mask used for bi-directional self-attention.
-     """
-     bsz, tgt_len = input_ids_shape
-     mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
-     mask_cond = torch.arange(mask.size(-1), device=device)
-     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
-     mask = mask.to(dtype)
-
-     if past_key_values_length > 0:
-         mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
-     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
-
-
- # Copied from transformers.models.bart.modeling_bart._expand_mask
- def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
-     """
-     Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
-     """
-     bsz, src_len = mask.size()
-     tgt_len = tgt_len if tgt_len is not None else src_len

-     expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

-     inverted_mask = 1.0 - expanded_mask
-
-     return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
-
-
- class MobiLlamaRMSNorm(nn.Module):
-     def __init__(self, hidden_size, eps=1e-6):
-         """
-         MobiLlamaRMSNorm is equivalent to T5LayerNorm
-         """
-         super().__init__()
-         self.weight = nn.Parameter(torch.ones(hidden_size))
-         self.variance_epsilon = eps
-
-     def forward(self, hidden_states):
-         input_dtype = hidden_states.dtype
-         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-
-         return (self.weight * hidden_states).to(input_dtype)
-
-
- class MobiLlamaRotaryEmbedding(torch.nn.Module):
-     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
-         super().__init__()
-         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
-         self.register_buffer("inv_freq", inv_freq)
-
-         # Build here to make `torch.jit.trace` work.
-         self.max_seq_len_cached = max_position_embeddings
-         t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
-         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-         # Different from paper, but it uses a different permutation in order to obtain the same calculation
-         emb = torch.cat((freqs, freqs), dim=-1)
-         self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
-         self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
-
-     def forward(self, x, seq_len=None):
-         # x: [bs, num_attention_heads, seq_len, head_size]
-         # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
-         if seq_len > self.max_seq_len_cached:
-             self.max_seq_len_cached = seq_len
-             t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
-             freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-             # Different from paper, but it uses a different permutation in order to obtain the same calculation
-             emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
-             self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
-             self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
-         return (
-             self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-             self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-         )
-
-
- def rotate_half(x):
-     """Rotates half the hidden dims of the input."""
-     x1 = x[..., : x.shape[-1] // 2]
-     x2 = x[..., x.shape[-1] // 2 :]
-     return torch.cat((-x2, x1), dim=-1)
-
-
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
-     # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
-     cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
-     sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
-     cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-     sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-     q_embed = (q * cos) + (rotate_half(q) * sin)
-     k_embed = (k * cos) + (rotate_half(k) * sin)
-     return q_embed, k_embed


- class MobiLlamaMLP(nn.Module):
      def __init__(
          self,
-         hidden_size: int,
-         intermediate_size: int,
-         hidden_act: str,
      ):
-         super().__init__()
-         self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
-         self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
-         self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
-         self.act_fn = ACT2FN[hidden_act]
-
-     def forward(self, x):
-         return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-
-
- class MobiLlamaAttention(nn.Module):
-     """Multi-headed attention from 'Attention Is All You Need' paper"""
-
-     def __init__(self, config: LlamaConfig):
          super().__init__()
          self.config = config
          self.hidden_size = config.hidden_size
          self.num_heads = config.num_attention_heads
          self.head_dim = self.hidden_size // self.num_heads
          self.max_position_embeddings = config.max_position_embeddings

          if (self.head_dim * self.num_heads) != self.hidden_size:
              raise ValueError(
                  f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                  f" and `num_heads`: {self.num_heads})."
              )
-         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
-         self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
-         self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
-         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
-         self.rotary_emb = MobiLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
-
-     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

      def forward(
          self,
          hidden_states: torch.Tensor,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         past_key_value: Optional[Tuple[torch.Tensor]] = None,
-         output_attentions: bool = False,
-         use_cache: bool = False,
-     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-         bsz, q_len, _ = hidden_states.size()
-
-         query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-         key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-         value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-
-         kv_seq_len = key_states.shape[-2]
-         if past_key_value is not None:
-             kv_seq_len += past_key_value[0].shape[-2]
-         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-         # [bsz, nh, t, hd]
-
-         if past_key_value is not None:
-             # reuse k, v, self_attention
-             key_states = torch.cat([past_key_value[0], key_states], dim=2)
-             value_states = torch.cat([past_key_value[1], value_states], dim=2)
-
-         past_key_value = (key_states, value_states) if use_cache else None
-
-         attn_output = flash_attn_func(
-             q=query_states.transpose(1, 2).to(torch.bfloat16),
-             k=key_states.transpose(1, 2).to(torch.bfloat16),
-             v=value_states.transpose(1, 2).to(torch.bfloat16),
-             causal=True)
-         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-         attn_output = attn_output.to(query_states.dtype)
-
          attn_output = self.o_proj(attn_output)

-         # if not output_attentions:
-         #     attn_weights = None
-         assert not output_attentions
-         attn_weights = None
-
-         return attn_output, attn_weights, past_key_value


  class MobiLlamaDecoderLayer(nn.Module):
-     def __init__(self, config: LlamaConfig):
          super().__init__()
          self.hidden_size = config.hidden_size
-         self.self_attn = MobiLlamaAttention(config=config)
-         self.mlp = MobiLlamaMLP(
-             hidden_size=self.hidden_size,
-             intermediate_size=config.intermediate_size,
-             hidden_act=config.hidden_act,
-         )
-         self.input_layernorm = MobiLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-         self.post_attention_layernorm = MobiLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

      def forward(
          self,
          hidden_states: torch.Tensor,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         past_key_value: Optional[Tuple[torch.Tensor]] = None,
-         output_attentions: Optional[bool] = False,
-         use_cache: Optional[bool] = False,
-     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-         """
-         Args:
-             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
-             attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
-                 `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
-             output_attentions (`bool`, *optional*):
-                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                 returned tensors for more detail.
-             use_cache (`bool`, *optional*):
-                 If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
-                 (see `past_key_values`).
-             past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
-         """
-
          residual = hidden_states
-
          hidden_states = self.input_layernorm(hidden_states)
-
-         # Self Attention
-         hidden_states, self_attn_weights, present_key_value = self.self_attn(
              hidden_states=hidden_states,
-             attention_mask=attention_mask,
-             position_ids=position_ids,
-             past_key_value=past_key_value,
-             output_attentions=output_attentions,
-             use_cache=use_cache,
          )
          hidden_states = residual + hidden_states
-
-         # Fully Connected
          residual = hidden_states
          hidden_states = self.post_attention_layernorm(hidden_states)
          hidden_states = self.mlp(hidden_states)
          hidden_states = residual + hidden_states

-         outputs = (hidden_states,)
-
-         if output_attentions:
-             outputs += (self_attn_weights,)
-
-         if use_cache:
-             outputs += (present_key_value,)
-
-         return outputs
-
-
- MOBILLAMA_START_DOCSTRING = r"""
-     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
-     library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
-     etc.)
-
-     This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
-     Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
-     and behavior.
-
-     Parameters:
-         config ([`LlamaConfig`]):
-             Model configuration class with all the parameters of the model. Initializing with a config file does not
-             load the weights associated with the model, only the configuration. Check out the
-             [`~PreTrainedModel.from_pretrained`] method to load the model weights.
- """
-
-
- @add_start_docstrings(
-     "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
-     MOBILLAMA_START_DOCSTRING,
- )
- class MobiLlamaPreTrainedModel(PreTrainedModel):
-     config_class = LlamaConfig
-     base_model_prefix = "model"
-     supports_gradient_checkpointing = True
-     _no_split_modules = ["MobiLlamaDecoderLayer"]
-     _skip_keys_device_placement = "past_key_values"
-     _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
-
-     def _init_weights(self, module):
-         std = self.config.initializer_range
-         if isinstance(module, nn.Linear):
-             module.weight.data.normal_(mean=0.0, std=std)
-             if module.bias is not None:
-                 module.bias.data.zero_()
-         elif isinstance(module, nn.Embedding):
-             module.weight.data.normal_(mean=0.0, std=std)
-             if module.padding_idx is not None:
-                 module.weight.data[module.padding_idx].zero_()
-
-     def _set_gradient_checkpointing(self, module, value=False):
-         if isinstance(module, MobiLlamaModel):
-             module.gradient_checkpointing = value
-
-
- MOBILLAMA_INPUTS_DOCSTRING = r"""
-     Args:
-         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-             Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
-             it.
-
-             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
-             [`PreTrainedTokenizer.__call__`] for details.
-
-             [What are input IDs?](../glossary#input-ids)
-         attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-             Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
-             - 1 for tokens that are **not masked**,
-             - 0 for tokens that are **masked**.
-
-             [What are attention masks?](../glossary#attention-mask)
-
-             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
-             [`PreTrainedTokenizer.__call__`] for details.
-
-             If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
-             `past_key_values`).
-
-             If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
-             and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
-             information on the default strategy.
-
-             - 1 indicates the head is **not masked**,
-             - 0 indicates the head is **masked**.
-         position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-             Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
-             config.n_positions - 1]`.
-
-             [What are position IDs?](../glossary#position-ids)
-         past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
-             Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
-             `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
-             `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
-
-             Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
-             blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
-
-             If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
-             don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
-             `decoder_input_ids` of shape `(batch_size, sequence_length)`.
-         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
-             Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
-             is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
-             model's internal embedding lookup matrix.
-         use_cache (`bool`, *optional*):
-             If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
-             `past_key_values`).
-         output_attentions (`bool`, *optional*):
-             Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
-             tensors for more detail.
-         output_hidden_states (`bool`, *optional*):
-             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
-             more detail.
-         return_dict (`bool`, *optional*):
-             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
- """
-
-
- @add_start_docstrings(
-     "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
-     MOBILLAMA_START_DOCSTRING,
- )
- class MobiLlamaModel(MobiLlamaPreTrainedModel):
-     """
-     Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MobiLlamaDecoderLayer`]
-
-     Args:
-         config: LlamaConfig
-     """
-
      def __init__(self, config: LlamaConfig):
-         super().__init__(config)
          self.padding_idx = config.pad_token_id
          self.vocab_size = config.vocab_size

-         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
-         self.layers = nn.ModuleList([MobiLlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
-         self.norm = MobiLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
-         self.gradient_checkpointing = False
-         # Initialize weights and apply final processing
-         self.post_init()
-
-     def get_input_embeddings(self):
-         return self.embed_tokens
-
-     def set_input_embeddings(self, value):
-         self.embed_tokens = value
-
-     # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
-     def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
-         # create causal mask
-         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-         combined_attention_mask = None
-         if input_shape[-1] > 1:
-             combined_attention_mask = _make_causal_mask(
-                 input_shape,
-                 inputs_embeds.dtype,
-                 device=inputs_embeds.device,
-                 past_key_values_length=past_key_values_length,
-             )
-
-         if attention_mask is not None:
-             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-             expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
-                 inputs_embeds.device
-             )
-             combined_attention_mask = (
-                 expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
-             )
-
-         return combined_attention_mask
-
-     @add_start_docstrings_to_model_forward(MOBILLAMA_INPUTS_DOCSTRING)
      def forward(
          self,
-         input_ids: torch.LongTensor = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         past_key_values: Optional[List[torch.FloatTensor]] = None,
-         inputs_embeds: Optional[torch.FloatTensor] = None,
-         use_cache: Optional[bool] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         return_dict: Optional[bool] = None,
-     ) -> Union[Tuple, BaseModelOutputWithPast]:
-         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-         output_hidden_states = (
-             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-         )
-         use_cache = use_cache if use_cache is not None else self.config.use_cache
-
-         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-         # retrieve input_ids and inputs_embeds
-         if input_ids is not None and inputs_embeds is not None:
-             raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
-         elif input_ids is not None:
-             batch_size, seq_length = input_ids.shape
-         elif inputs_embeds is not None:
-             batch_size, seq_length, _ = inputs_embeds.shape
-         else:
-             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
-
-         seq_length_with_past = seq_length
-         past_key_values_length = 0
-
-         if past_key_values is not None:
-             past_key_values_length = past_key_values[0][0].shape[2]
-             seq_length_with_past = seq_length_with_past + past_key_values_length
-
-         if position_ids is None:
-             device = input_ids.device if input_ids is not None else inputs_embeds.device
-             position_ids = torch.arange(
-                 past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
-             )
-             position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
-         else:
-             position_ids = position_ids.view(-1, seq_length).long()
-
-         if inputs_embeds is None:
-             inputs_embeds = self.embed_tokens(input_ids)
-         # embed positions
-         if attention_mask is None:
-             attention_mask = torch.ones(
-                 (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
              )
-         attention_mask = self._prepare_decoder_attention_mask(
-             attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
-         )
-
-         hidden_states = inputs_embeds
-
-         if self.gradient_checkpointing and self.training:
-             if use_cache:
-                 logger.warning_once(
-                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                 )
-                 use_cache = False
-
-         # decoder layers
-         all_hidden_states = () if output_hidden_states else None
-         all_self_attns = () if output_attentions else None
-         next_decoder_cache = () if use_cache else None
-
-         for idx, decoder_layer in enumerate(self.layers):
-             if output_hidden_states:
-                 all_hidden_states += (hidden_states,)
-
-             past_key_value = past_key_values[idx] if past_key_values is not None else None
-
-             if self.gradient_checkpointing and self.training:
-
-                 def create_custom_forward(module):
-                     def custom_forward(*inputs):
-                         # None for past_key_value
-                         return module(*inputs, output_attentions, None)
-
-                     return custom_forward
-
-                 layer_outputs = torch.utils.checkpoint.checkpoint(
-                     create_custom_forward(decoder_layer),
-                     hidden_states,
-                     attention_mask,
-                     position_ids,
-                     None,
-                 )
-             else:
-                 layer_outputs = decoder_layer(
-                     hidden_states,
-                     attention_mask=attention_mask,
-                     position_ids=position_ids,
-                     past_key_value=past_key_value,
-                     output_attentions=output_attentions,
-                     use_cache=use_cache,
-                 )
-
-             hidden_states = layer_outputs[0]
-
-             if use_cache:
-                 next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
-
-             if output_attentions:
-                 all_self_attns += (layer_outputs[1],)
-
          hidden_states = self.norm(hidden_states)

-         # add hidden states from the last decoder layer
-         if output_hidden_states:
-             all_hidden_states += (hidden_states,)
-
-         next_cache = next_decoder_cache if use_cache else None
-         if not return_dict:
-             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
-         return BaseModelOutputWithPast(
-             last_hidden_state=hidden_states,
-             past_key_values=next_cache,
-             hidden_states=all_hidden_states,
-             attentions=all_self_attns,
-         )
-
-
- class LlamaForCausalLM(MobiLlamaPreTrainedModel):
-     def __init__(self, config):
-         super().__init__(config)
          self.model = MobiLlamaModel(config)

-         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
-         # Initialize weights and apply final processing
-         self.post_init()
-
-     def get_input_embeddings(self):
-         return self.model.embed_tokens
-
-     def set_input_embeddings(self, value):
-         self.model.embed_tokens = value
-
-     def get_output_embeddings(self):
-         return self.lm_head
-
-     def set_output_embeddings(self, new_embeddings):
-         self.lm_head = new_embeddings
-
-     def set_decoder(self, decoder):
-         self.model = decoder
-
-     def get_decoder(self):
-         return self.model
-
-     @add_start_docstrings_to_model_forward(MOBILLAMA_INPUTS_DOCSTRING)
-     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
      def forward(
          self,
-         input_ids: torch.LongTensor = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         past_key_values: Optional[List[torch.FloatTensor]] = None,
-         inputs_embeds: Optional[torch.FloatTensor] = None,
-         labels: Optional[torch.LongTensor] = None,
-         use_cache: Optional[bool] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         return_dict: Optional[bool] = None,
-     ) -> Union[Tuple, CausalLMOutputWithPast]:
-         r"""
-         Args:
-             labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-                 Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-                 config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-                 (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
-         Returns:
-
-         Example:
-
-         ```python
-         >>> from transformers import AutoTokenizer, LlamaForCausalLM
-
-         >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
-         >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
-
-         >>> prompt = "Hey, are you consciours? Can you talk to me?"
-         >>> inputs = tokenizer(prompt, return_tensors="pt")
-
-         >>> # Generate
-         >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
-         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-         "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
-         ```"""
-
-         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-         output_hidden_states = (
-             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-         )
-         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-         outputs = self.model(
              input_ids=input_ids,
-             attention_mask=attention_mask,
-             position_ids=position_ids,
-             past_key_values=past_key_values,
-             inputs_embeds=inputs_embeds,
-             use_cache=use_cache,
-             output_attentions=output_attentions,
-             output_hidden_states=output_hidden_states,
-             return_dict=return_dict,
          )
-
-         hidden_states = outputs[0]
          logits = self.lm_head(hidden_states)

-         loss = None
-         if labels is not None:
-             # Shift so that tokens < n predict n
-             shift_logits = logits[..., :-1, :].contiguous()
-             shift_labels = labels[..., 1:].contiguous()
-             # Flatten the tokens
-             loss_fct = CrossEntropyLoss()
-             shift_logits = shift_logits.view(-1, self.config.vocab_size)
-             shift_labels = shift_labels.view(-1)
-             # Enable model parallelism
-             shift_labels = shift_labels.to(shift_logits.device)
-             loss = loss_fct(shift_logits, shift_labels)
-
-         if not return_dict:
-             output = (logits,) + outputs[1:]
-             return (loss,) + output if loss is not None else output
-
-         return CausalLMOutputWithPast(
-             loss=loss,
-             logits=logits,
-             past_key_values=outputs.past_key_values,
-             hidden_states=outputs.hidden_states,
-             attentions=outputs.attentions,
-         )
-
-     def prepare_inputs_for_generation(
-         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
-     ):
-         if past_key_values:
-             input_ids = input_ids[:, -1:]
-
-         position_ids = kwargs.get("position_ids", None)
-         if attention_mask is not None and position_ids is None:
-             # create position_ids on the fly for batch generation
-             position_ids = attention_mask.long().cumsum(-1) - 1
-             position_ids.masked_fill_(attention_mask == 0, 1)
-             if past_key_values:
-                 position_ids = position_ids[:, -1].unsqueeze(-1)
-
-         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-         if inputs_embeds is not None and past_key_values is None:
-             model_inputs = {"inputs_embeds": inputs_embeds}
-         else:
-             model_inputs = {"input_ids": input_ids}
-
-         model_inputs.update(
-             {
-                 "position_ids": position_ids,
-                 "past_key_values": past_key_values,
-                 "use_cache": kwargs.get("use_cache"),
-                 "attention_mask": attention_mask,
-             }
-         )
-         return model_inputs
-
-     @staticmethod
-     def _reorder_cache(past_key_values, beam_idx):
-         reordered_past = ()
-         for layer_past in past_key_values:
-             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
-         return reordered_past
-
-
- @add_start_docstrings(
-     """
-     The LLaMa Model transformer with a sequence classification head on top (linear layer).
-
-     [`MobiLlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
-     (e.g. GPT-2) do.
-
-     Since it does classification on the last token, it requires to know the position of the last token. If a
-     `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
-     no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
-     padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
-     each row of the batch).
-     """,
-     MOBILLAMA_START_DOCSTRING,
- )
- class MobiLlamaForSequenceClassification(MobiLlamaPreTrainedModel):
-     _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
-
-     def __init__(self, config):
-         super().__init__(config)
-         self.num_labels = config.num_labels
-         self.model = MobiLlamaModel(config)
-         self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
-
-         # Initialize weights and apply final processing
-         self.post_init()
-
-     def get_input_embeddings(self):
-         return self.model.embed_tokens
-
-     def set_input_embeddings(self, value):
-         self.model.embed_tokens = value
-
-     @add_start_docstrings_to_model_forward(MOBILLAMA_INPUTS_DOCSTRING)
-     def forward(
          self,
-         input_ids: torch.LongTensor = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         past_key_values: Optional[List[torch.FloatTensor]] = None,
-         inputs_embeds: Optional[torch.FloatTensor] = None,
-         labels: Optional[torch.LongTensor] = None,
-         use_cache: Optional[bool] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         return_dict: Optional[bool] = None,
-     ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
-         r"""
-         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
-             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
-             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
-         """
-         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-         transformer_outputs = self.model(
-             input_ids,
-             attention_mask=attention_mask,
-             position_ids=position_ids,
-             past_key_values=past_key_values,
-             inputs_embeds=inputs_embeds,
-             use_cache=use_cache,
-             output_attentions=output_attentions,
-             output_hidden_states=output_hidden_states,
-             return_dict=return_dict,
-         )
-         hidden_states = transformer_outputs[0]
-         logits = self.score(hidden_states)
-
-         if input_ids is not None:
-             batch_size = input_ids.shape[0]
-         else:
-             batch_size = inputs_embeds.shape[0]
-
-         if self.config.pad_token_id is None and batch_size != 1:
-             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
-         if self.config.pad_token_id is None:
-             sequence_lengths = -1
-         else:
-             if input_ids is not None:
-                 sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
-             else:
-                 sequence_lengths = -1
-
-         pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
-
-         loss = None
-         if labels is not None:
-             labels = labels.to(logits.device)
-             if self.config.problem_type is None:
-                 if self.num_labels == 1:
-                     self.config.problem_type = "regression"
-                 elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
-                     self.config.problem_type = "single_label_classification"
-                 else:
-                     self.config.problem_type = "multi_label_classification"
-
-             if self.config.problem_type == "regression":
-                 loss_fct = MSELoss()
-                 if self.num_labels == 1:
-                     loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
-                 else:
-                     loss = loss_fct(pooled_logits, labels)
-             elif self.config.problem_type == "single_label_classification":
-                 loss_fct = CrossEntropyLoss()
-                 loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
-             elif self.config.problem_type == "multi_label_classification":
-                 loss_fct = BCEWithLogitsLoss()
-                 loss = loss_fct(pooled_logits, labels)
-         if not return_dict:
-             output = (pooled_logits,) + transformer_outputs[1:]
-             return ((loss,) + output) if loss is not None else output
-
-         return SequenceClassifierOutputWithPast(
-             loss=loss,
-             logits=pooled_logits,
-             past_key_values=transformer_outputs.past_key_values,
-             hidden_states=transformer_outputs.hidden_states,
-             attentions=transformer_outputs.attentions,
-         )
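
For reference (not part of the commit): the removed attention path above calls flash_attn_func on [batch, seq_len, num_heads, head_dim] tensors cast to bfloat16 with causal=True. For the full-sequence (no-cache) case, a minimal plain-PyTorch sketch of the same computation, using scaled_dot_product_attention, which expects [batch, num_heads, seq_len, head_dim] and hence needs transposes:

import torch
import torch.nn.functional as F

def causal_attention_reference(q, k, v):
    # q, k, v: [batch, seq_len, num_heads, head_dim], the layout flash_attn_func takes.
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
    )
    # Back to [batch, seq_len, num_heads, head_dim].
    return out.transpose(1, 2)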
 
  # coding=utf-8
  # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+ # Copyright 2023 MobiLLaMA team.
+ # Copyright 2023 vLLM team.
  #
  # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+ # and OPT implementations in this library. It has been modified for vLLM's model execution architecture.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
+ """vLLM implementation of MobiLLaMA model."""
+
+ from typing import Dict, List, Optional, Tuple
+ from vllm.config import CacheConfig
+ from vllm.model_executor.layers.quantization import QuantizationConfig

  import torch
  from torch import nn
+ from transformers import LlamaConfig

+ from vllm.model_executor.layers.activation import SiluAndMul
+ from vllm.attention import Attention
+ from vllm.model_executor.layers.linear import (QKVParallelLinear, RowParallelLinear,
+                                                ColumnParallelLinear)
+ from vllm.model_executor.layers.layernorm import RMSNorm
+ from vllm.model_executor.layers.rotary_embedding import get_rope
+ from vllm.model_executor.layers.sampler import Sampler

+ from vllm.model_executor.sampling_metadata import SamplingMetadata
+ from vllm.model_executor.model_loader.weight_utils import (default_weight_loader,
+                                                            pt_weights_iterator)
+ from vllm.model_executor.layers.sampler import SamplerOutput

+ class MobiLlamaAttention(nn.Module):
+     """Multi-headed attention from the paper 'Attention Is All You Need'"""

      def __init__(
          self,
+         config: LlamaConfig,
+         layer_idx: int,
      ):
          super().__init__()
          self.config = config
+         self.layer_idx = layer_idx
+
          self.hidden_size = config.hidden_size
          self.num_heads = config.num_attention_heads
          self.head_dim = self.hidden_size // self.num_heads
          self.max_position_embeddings = config.max_position_embeddings
+         self.scaling = self.head_dim**-0.5

          if (self.head_dim * self.num_heads) != self.hidden_size:
              raise ValueError(
                  f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                  f" and `num_heads`: {self.num_heads})."
              )
+
+         self.num_key_value_heads = getattr(config, "num_key_value_heads", self.num_heads)
+         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+
+         # In vLLM implementation, we use combined QKV projection
+         self.qkv_proj = QKVParallelLinear(
+             self.hidden_size,
+             self.head_dim,
+             self.num_heads,
+             self.num_key_value_heads,
+             bias=False,
+         )
+
+         self.o_proj = RowParallelLinear(
+             self.hidden_size,
+             self.hidden_size,
+             bias=False,
+         )
+
+         # Set up rotary embedding
+         rope_theta = getattr(config, "rope_theta", 10000)
+         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+         self.rotary_emb = get_rope(
+             self.head_dim,
+             rotary_dim=self.head_dim,
+             max_position=max_position_embeddings,
+             base=rope_theta,
+         )
+
+         self.attn = Attention(
+             self.num_heads,
+             self.head_dim,
+             self.scaling,
+             num_kv_heads=self.num_key_value_heads
+         )

      def forward(
          self,
+         positions: torch.Tensor,
          hidden_states: torch.Tensor,
+         kv_cache: torch.Tensor,
+         attn_metadata: Dict,
+     ) -> torch.Tensor:
+         qkv = self.qkv_proj(hidden_states)
+         q, k, v = qkv.split([
+             self.num_heads * self.head_dim,
+             self.num_key_value_heads * self.head_dim,
+             self.num_key_value_heads * self.head_dim
+         ], dim=-1)
+
+         # Reshape for rotary embedding
+         q = q.view(-1, self.num_heads, self.head_dim)
+         k = k.view(-1, self.num_key_value_heads, self.head_dim)
+         v = v.view(-1, self.num_key_value_heads, self.head_dim)
+
+         # Apply rotary embedding
+         q, k = self.rotary_emb(positions, q, k)
+
+         # Run attention
+         attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+
+         # Reshape output and project back to hidden size
+         attn_output = attn_output.reshape(*hidden_states.shape[:-1], self.hidden_size)
          attn_output = self.o_proj(attn_output)
+
+         return attn_output
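
A note outside the diff: the fused qkv_proj above replaces the three separate q_proj/k_proj/v_proj linears of the Transformers version, and the split sizes in forward mirror that fusion. A minimal plain-PyTorch sketch (sizes are illustrative) of why concatenating the three weight matrices along dim 0 is exactly equivalent:

import torch
from torch import nn

hidden_size, num_heads, num_kv_heads, head_dim = 64, 4, 2, 16
q = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
k = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)
v = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)

# Fused projection: weights stacked [q; k; v] along the output dimension.
fused = nn.Linear(hidden_size, (num_heads + 2 * num_kv_heads) * head_dim, bias=False)
with torch.no_grad():
    fused.weight.copy_(torch.cat([q.weight, k.weight, v.weight], dim=0))

x = torch.randn(3, hidden_size)
# Splitting the fused output recovers the separate projections exactly.
q_out, k_out, v_out = fused(x).split(
    [num_heads * head_dim, num_kv_heads * head_dim, num_kv_heads * head_dim], dim=-1)
assert torch.allclose(q_out, q(x), atol=1e-6)
assert torch.allclose(v_out, v(x), atol=1e-6)

One caveat, hedged: in the vLLM releases this editor has seen, the parallel linear layers (QKVParallelLinear, RowParallelLinear, ColumnParallelLinear) return an (output, bias) tuple from forward, so callers typically unpack, e.g. `qkv, _ = self.qkv_proj(hidden_states)`; the code above uses the return value directly.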
+ class MobiLlamaMLP(nn.Module):
+     def __init__(
+         self,
+         config: LlamaConfig,
+     ):
+         super().__init__()
+         hidden_size = config.hidden_size
+         intermediate_size = config.intermediate_size
+
+         # In vLLM, we use ColumnParallelLinear for gate_proj and up_proj
+         self.gate_proj = ColumnParallelLinear(
+             hidden_size,
+             intermediate_size,
+             bias=False,
+         )
+         self.up_proj = ColumnParallelLinear(
+             hidden_size,
+             intermediate_size,
+             bias=False,
+         )
+         self.down_proj = RowParallelLinear(
+             intermediate_size,
+             hidden_size,
+             bias=False,
+         )
+         self.act_fn = SiluAndMul()

+     def forward(self, x):
+         gate_output = self.gate_proj(x)
+         up_output = self.up_proj(x)
+
+         # Apply SiLU activation and multiply
+         intermediate_output = self.act_fn(gate_output, up_output)
+
+         # Project back to hidden size
+         output = self.down_proj(intermediate_output)
+         return output
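
Outside the diff: this MLP computes the same SwiGLU as the deleted `self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))`. In the vLLM releases this editor is aware of, SiluAndMul.forward takes a single tensor whose last dimension stacks [gate, up] (usually produced by one fused MergedColumnParallelLinear), so the two-argument call above assumes a different interface. A plain-PyTorch reference for the activation step:

import torch
import torch.nn.functional as F

def silu_and_mul(gate_up: torch.Tensor) -> torch.Tensor:
    # gate_up: [..., 2 * intermediate_size], concatenating gate and up halves.
    gate, up = gate_up.chunk(2, dim=-1)
    return F.silu(gate) * up

x = torch.randn(4, 2 * 256)
assert silu_and_mul(x).shape == (4, 256)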
  class MobiLlamaDecoderLayer(nn.Module):
+     def __init__(
+         self,
+         config: LlamaConfig,
+         layer_idx: int,
+     ):
          super().__init__()
          self.hidden_size = config.hidden_size
+
+         # Layer norms
+         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+         # Self-attention
+         self.self_attn = MobiLlamaAttention(config=config, layer_idx=layer_idx)
+
+         # MLP
+         self.mlp = MobiLlamaMLP(config)

      def forward(
          self,
+         positions: torch.Tensor,
          hidden_states: torch.Tensor,
+         kv_cache: List[torch.Tensor],
+         attn_metadata: Dict,
+     ) -> torch.Tensor:
+         # Layernorm before self-attention
          residual = hidden_states
          hidden_states = self.input_layernorm(hidden_states)
+
+         # Self-attention
+         hidden_states = self.self_attn(
+             positions=positions,
              hidden_states=hidden_states,
+             kv_cache=kv_cache[0],
+             attn_metadata=attn_metadata,
          )
+
+         # First residual connection
          hidden_states = residual + hidden_states
+
+         # Layernorm before MLP
          residual = hidden_states
          hidden_states = self.post_attention_layernorm(hidden_states)
+
+         # MLP
          hidden_states = self.mlp(hidden_states)
+
+         # Second residual connection
          hidden_states = residual + hidden_states
+
+         return hidden_states
+ class MobiLlamaModel(nn.Module):
      def __init__(self, config: LlamaConfig):
+         super().__init__()
+         self.config = config
          self.padding_idx = config.pad_token_id
          self.vocab_size = config.vocab_size
+
+         # Token embedding
+         self.embed_tokens = nn.Embedding(
+             config.vocab_size,
+             config.hidden_size,
+             self.padding_idx
+         )
+
+         # Decoder layers
+         self.layers = nn.ModuleList([
+             MobiLlamaDecoderLayer(config, i)
+             for i in range(config.num_hidden_layers)
+         ])
+
+         # Final layernorm
+         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

      def forward(
          self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         kv_caches: List[List[torch.Tensor]],
+         attn_metadata: Dict,
+     ) -> torch.Tensor:
+         # Get token embeddings
+         hidden_states = self.embed_tokens(input_ids)
+
+         # Forward through each decoder layer
+         for i, layer in enumerate(self.layers):
+             hidden_states = layer(
+                 positions=positions,
+                 hidden_states=hidden_states,
+                 kv_cache=kv_caches[i],
+                 attn_metadata=attn_metadata,
              )
+
+         # Apply final layernorm
          hidden_states = self.norm(hidden_states)
+
+         return hidden_states
+ class MobiLlamaForCausalLM(nn.Module):
+     def __init__(self, config: LlamaConfig, quant_config: Optional[QuantizationConfig] = None, cache_config: Optional[CacheConfig] = None):
+         super().__init__()
+         self.config = config
+
+         # Core MobiLLaMA model
          self.model = MobiLlamaModel(config)
+
+         # LM head
+         self.lm_head = nn.Linear(
+             config.hidden_size,
+             config.vocab_size,
+             bias=False
+         )
+
+         # Sampling module
+         self.sampler = Sampler()

      def forward(
          self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         kv_caches: List[List[torch.Tensor]],
+         attn_metadata: Dict,
+     ) -> torch.Tensor:
+         hidden_states = self.model(
              input_ids=input_ids,
+             positions=positions,
+             kv_caches=kv_caches,
+             attn_metadata=attn_metadata,
          )
+
+         # Apply LM head
          logits = self.lm_head(hidden_states)
+
+         return logits

+     def sample(
          self,
+         logits: torch.Tensor,
+         sampling_metadata: SamplingMetadata,
+     ) -> SamplerOutput:
+         return self.sampler(logits, sampling_metadata)
+
+     def load_weights(self, weights):
+         # First use default loader for most weights
+         state_dict = self.state_dict()
+         for name, param in pt_weights_iterator(weights):
+             if "rotary_emb" in name:
+                 # Skip rotary embedding weights as they're handled differently in vLLM
+                 continue
+
+             # vLLM uses a combined QKV projection
+             if any(n in name for n in ["q_proj", "k_proj", "v_proj"]):
+                 # These weights will be loaded separately through the QKVParallelLinear
+                 continue
+
+             param_name = name
+
+             # Handle mapping between HF and vLLM naming schemes
+             if "self_attn.o_proj" in name:
+                 param_name = name.replace("self_attn.o_proj", "self_attn.o_proj.weight")
+             elif "mlp.gate_proj" in name:
+                 param_name = name.replace("mlp.gate_proj", "mlp.gate_proj.weight")
+             elif "mlp.up_proj" in name:
+                 param_name = name.replace("mlp.up_proj", "mlp.up_proj.weight")
+             elif "mlp.down_proj" in name:
+                 param_name = name.replace("mlp.down_proj", "mlp.down_proj.weight")
+
+             if param_name in state_dict:
+                 state_dict[param_name].copy_(param)
+
+         # Separately handle the QKV projections to combine them
+         for idx, layer in enumerate(self.model.layers):
+             # Get weights for q_proj, k_proj, v_proj
+             q_weight = weights[f"model.layers.{idx}.self_attn.q_proj.weight"]
+             k_weight = weights[f"model.layers.{idx}.self_attn.k_proj.weight"]
+             v_weight = weights[f"model.layers.{idx}.self_attn.v_proj.weight"]
+
+             # Set the combined QKV weight
+             layer.self_attn.qkv_proj.weight.data.copy_(
+                 torch.cat([q_weight, k_weight, v_weight], dim=0)
+             )
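
Beyond the diff, a note on wiring: interfaces vary across vLLM versions. In the ones this editor knows, load_weights receives an iterable of (name, tensor) pairs rather than a dict, so the dictionary-style `weights[...]` lookups in the final loop assume the caller materializes a mapping first, and pt_weights_iterator is normally fed checkpoint file paths, not tensors. A hypothetical registration/usage sketch (the module path and model id are illustrative, and assume this file is importable locally):

from vllm import LLM, ModelRegistry
from modeling_mobillama import MobiLlamaForCausalLM  # this file, on PYTHONPATH

# Register the architecture name that appears in the checkpoint's config.json.
ModelRegistry.register_model("MobiLlamaForCausalLM", MobiLlamaForCausalLM)

llm = LLM(model="MBZUAI/MobiLlama-05B", trust_remote_code=True)
print(llm.generate("Hello, my name is")[0].outputs[0].text)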