from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import MODEL_FOR_MASKED_LM_MAPPING
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaModel,
)
|
|
@dataclass
class ModelOutput:
    loss: Optional[torch.Tensor]
    logits: torch.Tensor
|
|
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.

    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
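

# A minimal shape sketch for the two helpers above (illustrative only: the
# sizes are made up, and cos/sin are built directly with torch rather than
# with transformers' rotary-embedding module).
def _rope_shape_example():
    batch, heads, seq_len, head_dim = 2, 4, 8, 16
    q = torch.randn(batch, heads, seq_len, head_dim)
    k = torch.randn(batch, heads, seq_len, head_dim)
    angles = torch.randn(batch, seq_len, head_dim)
    cos, sin = angles.cos(), angles.sin()
    # unsqueeze_dim=1 broadcasts (batch, seq_len, head_dim) over the heads axis.
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
    assert q_rot.shape == q.shape and k_rot.shape == k.shape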


class CustomLlamaConfig(LlamaConfig):
    """Configuration for the custom Llama variant; carries all standard Llama fields."""

    model_type = "custom_llama"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)


class CustomLlamaAttention(LlamaAttention):
    def __init__(self, config, layer_idx: int):
        super().__init__(config, layer_idx)
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.scale = 1.0 / (self.head_dim ** 0.5)

        # Each query/key is represented by a "start" point and a "direction"
        # offset, each with its own projection; values keep a single projection.
        self.w_q_start = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.w_q_dir = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.w_k_start = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.w_k_dir = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.w_v = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=config.attention_bias)

        # The standard Llama projections are replaced by the ones above.
        del self.q_proj, self.k_proj, self.v_proj

    def _compute_metric(self, q_start, q_dir, k_start, k_dir):
        # scores = q_start·k_start + q_dir·k_start + q_start·k_dir, which is
        # algebraically (q_start + q_dir)·(k_start + k_dir) - q_dir·k_dir.
        std_term = torch.einsum("bhqd,bhkd->bhqk", q_start, k_start)
        cross_term1 = torch.einsum("bhqd,bhkd->bhqk", q_dir, k_start)
        cross_term2 = torch.einsum("bhqd,bhkd->bhqk", q_start, k_dir)
        scores = (std_term + cross_term1 + cross_term2) * self.scale
        return scores
|
|
    def _get_causal_mask(self, query_length, key_length, device):
        # Additive mask with -inf strictly above the diagonal, e.g. for length 3:
        #   [[0, -inf, -inf],
        #    [0,    0, -inf],
        #    [0,    0,    0]]
        return torch.triu(
            torch.full((query_length, key_length), float('-inf'), device=device),
            diagonal=1,
        ).unsqueeze(0).unsqueeze(0)
|
|
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        is_causal: Optional[torch.Tensor] = None,
    ):
        batch_size, seq_len, _ = hidden_states.size()

        # Project into start/direction components; the direction is the offset
        # between the two projections.
        q_base = self.w_q_start(hidden_states)
        q_dir = self.w_q_dir(hidden_states) - q_base
        k_base = self.w_k_start(hidden_states)
        k_dir = self.w_k_dir(hidden_states) - k_base
        value = self.w_v(hidden_states)

        # Reshape to (batch, heads, seq_len, head_dim).
        q_start = q_base.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        q_dir = q_dir.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k_start = k_base.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k_dir = k_dir.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Rotary embeddings are applied to both the start and direction components.
        cos, sin = position_embeddings
        q_start, k_start = apply_rotary_pos_emb(q_start, k_start, cos, sin)
        q_dir, k_dir = apply_rotary_pos_emb(q_dir, k_dir, cos, sin)

        attn_scores = self._compute_metric(q_start, q_dir, k_start, k_dir)

        # Padding mask: True at positions that should be ignored.
        if attention_mask is not None:
            padding_mask = (attention_mask == 0).view(batch_size, 1, 1, -1)
            padding_mask = padding_mask.expand(-1, self.num_heads, -1, -1)
        else:
            padding_mask = None

        # Per-sample causal masking: `is_causal` is a boolean tensor of shape
        # (batch_size,) selecting which sequences are attended to causally.
        if is_causal is not None:
            causal_mask = self._get_causal_mask(seq_len, seq_len, attn_scores.device)
            causal_mask = causal_mask.expand(batch_size, self.num_heads, -1, -1)
            is_causal = is_causal.view(-1, 1, 1, 1)
            combined_mask = torch.where(is_causal, causal_mask, 0.0)
        else:
            combined_mask = 0.0

        attn_scores = attn_scores + combined_mask
        if padding_mask is not None:
            attn_scores = attn_scores.masked_fill(padding_mask, torch.finfo(attn_scores.dtype).min)

        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)

        attn_output = torch.matmul(attn_weights, value)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.num_heads * self.head_dim)
        attn_output = self.o_proj(attn_output)

        if output_attentions:
            return attn_output, attn_weights
        return attn_output, None
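

# Illustrative stand-alone call of the attention block above. The config
# values and tensor sizes here are made up; cos/sin are built by hand (an
# identity rotation) instead of using the model's rotary-embedding module.
def _attention_smoke_test():
    cfg = LlamaConfig(
        vocab_size=128,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=1,
        num_attention_heads=4,
        num_key_value_heads=4,
    )
    attn = CustomLlamaAttention(cfg, layer_idx=0)
    hidden = torch.randn(2, 8, cfg.hidden_size)
    head_dim = cfg.hidden_size // cfg.num_attention_heads
    angles = torch.zeros(2, 8, head_dim)
    cos, sin = angles.cos(), angles.sin()
    out, _ = attn(
        hidden,
        position_embeddings=(cos, sin),
        # One causally-masked sequence and one fully-visible sequence.
        is_causal=torch.tensor([True, False]),
    )
    assert out.shape == hidden.shape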


class CustomLlamaDecoderLayer(LlamaDecoderLayer):
    def __init__(self, config, layer_idx):
        super().__init__(config, layer_idx)
        self.self_attn = CustomLlamaAttention(config, layer_idx)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        is_causal: Optional[torch.Tensor] = None,
    ):
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            is_causal=is_causal,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
|
|
class CustomLlamaModel(LlamaModel):
    def __init__(self, config):
        super().__init__(config)
        self.layers = nn.ModuleList([
            CustomLlamaDecoderLayer(config, layer_idx=i)
            for i in range(config.num_hidden_layers)
        ])

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        is_causal: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ):
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        hidden_states = inputs_embeds

        # Derive position ids from the padding mask when they are not given.
        if position_ids is None:
            if attention_mask is not None:
                position_ids = (attention_mask.long().cumsum(dim=1) - 1).masked_fill(attention_mask == 0, 0)
            else:
                position_ids = torch.arange(hidden_states.shape[1], device=hidden_states.device).unsqueeze(0)
        cos, sin = self.rotary_emb(hidden_states, position_ids=position_ids)

        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=(cos, sin),
                is_causal=is_causal,
            )

        hidden_states = self.norm(hidden_states)
        return hidden_states
|
|
class CustomLlamaForCausalLM(LlamaForCausalLM):
    config_class = CustomLlamaConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = CustomLlamaModel(config)
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        is_causal: Optional[torch.Tensor] = None,
    ):
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_causal=is_causal,
        )

        hidden_states = outputs
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Default to the causal objective for every sequence if no flag is given.
            if is_causal is None:
                is_causal = torch.ones(labels.size(0), dtype=torch.bool, device=labels.device)
            is_causal = is_causal.to(labels.device)

            # Causal sequences: standard next-token prediction (shift by one).
            causal_logits = logits[is_causal][..., :-1, :].contiguous()
            causal_labels = labels[is_causal][..., 1:].contiguous()

            # Non-causal (masked-objective) sequences use the same shifted layout.
            masked_logits = logits[~is_causal][..., :-1, :].contiguous()
            masked_labels = labels[~is_causal][..., 1:].contiguous()

            loss = 0.0
            if causal_logits.numel() > 0:
                loss += F.cross_entropy(
                    causal_logits.view(-1, causal_logits.size(-1)),
                    causal_labels.view(-1),
                    ignore_index=-100,
                )
            if masked_logits.numel() > 0:
                loss += F.cross_entropy(
                    masked_logits.view(-1, masked_logits.size(-1)),
                    masked_labels.view(-1),
                    ignore_index=-100,
                )

        return ModelOutput(loss=loss, logits=logits)
|
|
class CustomLlamaForMaskedLM(CustomLlamaForCausalLM):
    config_class = CustomLlamaConfig

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ):
        # Every sequence in the batch uses the masked (non-causal) objective.
        if input_ids is not None:
            batch_size, device = input_ids.size(0), input_ids.device
        else:
            batch_size, device = inputs_embeds.size(0), inputs_embeds.device
        is_causal = torch.zeros(batch_size, dtype=torch.bool, device=device)

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_causal=is_causal,
        )


from transformers import MODEL_MAPPING


def _register():
    from transformers import AutoConfig

    # Make the custom config and model discoverable through the Auto* machinery.
    AutoConfig.register("custom_llama", CustomLlamaConfig)
    MODEL_MAPPING.register(CustomLlamaConfig, CustomLlamaForMaskedLM)
    MODEL_FOR_MASKED_LM_MAPPING.register(CustomLlamaConfig, CustomLlamaForMaskedLM)


_register()
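

if __name__ == "__main__":
    # Tiny randomly-initialised smoke test. All sizes below are made up for
    # illustration; this is a sketch of how the classes above fit together,
    # not a training recipe or a real checkpoint.
    config = CustomLlamaConfig(
        vocab_size=128,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
        max_position_embeddings=64,
    )
    model = CustomLlamaForCausalLM(config)

    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)
    labels = input_ids.clone()
    # First sequence is trained causally, second with the masked objective.
    is_causal = torch.tensor([True, False])

    out = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        labels=labels,
        is_causal=is_causal,
    )
    print(out.loss, out.logits.shape)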
|
|