|
|
from typing import List, Optional, Tuple, Union |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from torch.nn import CrossEntropyLoss |
|
|
from transformers.models.llama.modeling_llama import ( |
|
|
LlamaForCausalLM, |
|
|
CausalLMOutputWithPast, |
|
|
add_start_docstrings_to_model_forward, |
|
|
LLAMA_INPUTS_DOCSTRING, |
|
|
replace_return_docstrings, |
|
|
_CONFIG_FOR_DOC, |
|
|
LlamaModel, |
|
|
BaseModelOutputWithPast, |
|
|
logger, |
|
|
Cache, |
|
|
DynamicCache, |
|
|
StaticCache, |
|
|
repeat_kv, |
|
|
apply_rotary_pos_emb, |
|
|
LlamaSdpaAttention |
|
|
) |
|
|
|
|
|
from typing import List, Optional, Tuple, Union |
|
|
|
|
|
|
|
|
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def LlamaForCausalLMforward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    cot_start_idx: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        cot_start_idx (`torch.LongTensor`, *optional*):
            Not interpreted here; handed unchanged to `self.model` as the `cot_start_idx` keyword argument.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, LlamaForCausalLM

    >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
    >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""
    # Per-call flags win; anything left as None falls back to the model config.
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if return_dict is None:
        return_dict = self.config.use_return_dict

    # Run the decoder stack; `cot_start_idx` is simply passed along.
    backbone_out = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        cot_start_idx=cot_start_idx,
    )
    final_hidden = backbone_out[0]

    # Project hidden states to vocabulary logits, optionally slice by slice
    # when the config asks for `pretraining_tp` > 1.
    tp_degree = self.config.pretraining_tp
    if tp_degree > 1:
        head_shards = self.lm_head.weight.split(self.vocab_size // tp_degree, dim=0)
        partial_logits = [F.linear(final_hidden, head_shards[i]) for i in range(tp_degree)]
        logits = torch.cat(partial_logits, dim=-1)
    else:
        logits = self.lm_head(final_hidden)
    logits = logits.float()

    loss = None
    if labels is not None:
        # Position t predicts token t + 1: drop the last logit and the first label,
        # then flatten both for the cross-entropy criterion.
        flat_logits = logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size)
        flat_targets = labels[..., 1:].contiguous().view(-1).to(flat_logits.device)
        criterion = CrossEntropyLoss()
        loss = criterion(flat_logits, flat_targets)

    if not return_dict:
        tuple_out = (logits,) + backbone_out[1:]
        if loss is None:
            return tuple_out
        return (loss,) + tuple_out

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=backbone_out.past_key_values,
        hidden_states=backbone_out.hidden_states,
        attentions=backbone_out.attentions,
    )
|
|
|
|
|
|
|
|
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def LlamaModelforward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    cot_start_idx: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    """
    Decoder-stack forward pass with an optional per-sample mask override.

    When both a mask is produced by `self._update_causal_mask` and `cot_start_idx` is given,
    every mask row whose index (along dim -2) is `>= cot_start_idx` for that sample is replaced
    by the mask's last row. If `self._update_causal_mask` returns `None`, `cot_start_idx` has
    no effect.

    Raises:
        ValueError: if `input_ids` and `inputs_embeds` are both given or both missing, or if a
            `StaticCache` is supplied without `cache_position`.
    """
    # Per-call flags win; anything left as None falls back to the model config.
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if use_cache is None:
        use_cache = self.config.use_cache
    if return_dict is None:
        return_dict = self.config.use_return_dict

    # Exactly one of the two input forms has to be supplied.
    if (input_ids is None) == (inputs_embeds is None):
        raise ValueError(
            "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
        )

    checkpointing = self.gradient_checkpointing and self.training
    if checkpointing and use_cache:
        logger.warning_once(
            "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
        )
        use_cache = False

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    # Legacy / missing caches are wrapped in a DynamicCache so the number of
    # already-cached tokens can be queried; a StaticCache is left untouched.
    past_seen_tokens = 0
    if use_cache and not isinstance(past_key_values, StaticCache):
        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
        past_seen_tokens = past_key_values.get_seq_length()

    if cache_position is None:
        if isinstance(past_key_values, StaticCache):
            raise ValueError("cache_position is a required argument when using StaticCache.")
        first_new = past_seen_tokens
        cache_position = torch.arange(
            first_new, first_new + inputs_embeds.shape[1], device=inputs_embeds.device
        )

    if position_ids is None:
        position_ids = cache_position.unsqueeze(0)

    causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)

    if causal_mask is not None and cot_start_idx is not None:
        # Rows at or beyond each sample's start index take over the last row's mask;
        # earlier rows keep the mask computed above.
        n_rows = causal_mask.shape[-2]
        row_index = torch.arange(n_rows, device=causal_mask.device).view(1, 1, n_rows, 1)
        per_sample_start = cot_start_idx.view(causal_mask.shape[0], 1, 1, 1)
        final_row = causal_mask[:, :, -1:, :].clone()
        causal_mask = torch.where(row_index >= per_sample_start, final_row, causal_mask)

    hidden_states = inputs_embeds

    collected_hidden = () if output_hidden_states else None
    collected_attn = () if output_attentions else None
    latest_cache = None
    # The cache sits after the attention weights in a layer's output when those are returned.
    cache_slot = 2 if output_attentions else 1

    for layer in self.layers:
        if output_hidden_states:
            collected_hidden += (hidden_states,)

        if checkpointing:
            layer_out = self._gradient_checkpointing_func(
                layer.__call__,
                hidden_states,
                causal_mask,
                position_ids,
                past_key_values,
                output_attentions,
                use_cache,
                cache_position,
            )
        else:
            layer_out = layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )

        hidden_states = layer_out[0]

        if use_cache:
            latest_cache = layer_out[cache_slot]

        if output_attentions:
            collected_attn += (layer_out[1],)

    hidden_states = self.norm(hidden_states)

    # Also record the normalised output of the final layer.
    if output_hidden_states:
        collected_hidden += (hidden_states,)

    next_cache = None
    if use_cache:
        if isinstance(latest_cache, Cache):
            next_cache = latest_cache.to_legacy_cache()
        else:
            next_cache = latest_cache

    if not return_dict:
        candidates = [hidden_states, next_cache, collected_hidden, collected_attn]
        return tuple(item for item in candidates if item is not None)

    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=collected_hidden,
        attentions=collected_attn,
    )
|
|
|
|
|
|
|
|
def LlamaSdpaAttentionforward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """
    Replacement `forward` for `LlamaSdpaAttention` built on `scaled_dot_product_attention`.

    The supplied `attention_mask` is used as-is (trimmed to the key length) and SDPA is always
    called with `is_causal=False`, so causality comes only from the mask that is passed in.

    When `output_attentions` is true, the call is delegated to the parent class of
    `LlamaSdpaAttention`, because SDPA cannot return attention weights.

    Returns:
        `(attn_output, None, past_key_value)`; the attention-weights slot is always `None` on
        the SDPA path.
    """
    if output_attentions:
        logger.warning_once(
            "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
            'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
        )
        # This function lives at module level and is attached to the class afterwards,
        # so the zero-argument `super()` has no `__class__` cell and would raise
        # RuntimeError. Name the class explicitly to reach the parent implementation.
        return super(LlamaSdpaAttention, self).forward(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
        )

    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    # Split the projections into heads: (bsz, heads, q_len, head_dim).
    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    cos, sin = self.rotary_emb(value_states, position_ids)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    # A cache stored on the module instance takes precedence over the argument.
    past_key_value = getattr(self, "past_key_value", past_key_value)

    if past_key_value is not None:
        # Hand the rotary terms and positions to the cache along with the new keys/values.
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    # Expand the key/value heads so they match the number of query heads.
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    causal_mask = attention_mask
    if attention_mask is not None:
        # Keep only the mask columns that correspond to existing key positions.
        causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

    # Presumably the upstream workaround for SDPA misbehaving on CUDA with
    # non-contiguous inputs plus a custom mask (pytorch/pytorch#112577) -- TODO confirm.
    if query_states.device.type == "cuda" and causal_mask is not None:
        query_states = query_states.contiguous()
        key_states = key_states.contiguous()
        value_states = value_states.contiguous()

    # NOTE(review): `is_causal` is hard-coded to False, so when `causal_mask` is None the
    # attention is fully bidirectional. Confirm that callers always supply a mask, or that
    # this is intended.
    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query_states,
        key_states,
        value_states,
        attn_mask=causal_mask,
        dropout_p=self.attention_dropout if self.training else 0.0,
        is_causal=False,
    )

    # Merge the heads back: (bsz, q_len, hidden_size).
    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.view(bsz, q_len, self.hidden_size)

    attn_output = self.o_proj(attn_output)

    return attn_output, None, past_key_value
|
|
|
|
|
def monkey_patch_llama_forward():
    """Install this module's forward functions on the three imported Llama classes."""
    replacements = (
        (LlamaForCausalLM, LlamaForCausalLMforward),
        (LlamaModel, LlamaModelforward),
        (LlamaSdpaAttention, LlamaSdpaAttentionforward),
    )
    # Assigning on the class changes `forward` for every existing and future instance.
    for target_cls, patched_forward in replacements:
        target_cls.forward = patched_forward
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def monkey_patch_llama():
    """Apply the Llama forward patches, then report completion on stdout."""
    monkey_patch_llama_forward()
    # Plain print, so the confirmation shows up regardless of logging configuration.
    print("Monkey patched Llama.")
|
|
|