Upload cnets.py with huggingface_hub
cnets.py
ADDED
@@ -0,0 +1,890 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch LLaMA model."""
import math
from typing import List, Optional, Tuple, Union
from collections import Counter
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
import os
from transformers.integrations.deepspeed import HfDeepSpeedConfig
from transformers.activations import ACT2FN
from transformers import AutoTokenizer
from modeling_llama_kv import LlamaForCausalLM
from modeling_qwen_kv import Qwen3ForCausalLM
from configs import EConfig
from safetensors import safe_open
from datasets import load_dataset
import multiprocessing


# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)


# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class LlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )


class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        t = t / self.scaling_factor

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)


class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)


class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size * 2, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size * 2, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size * 2, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self._init_rope()

    def _init_rope(self):
        if self.config.rope_scaling is None:
            self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                    self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                    self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        cache_hidden: Optional[List[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        lck = len(cache_hidden[0])

        # cache_k = [self.k_proj(hidden) for hidden in cache_hidden]
        # cache_v = [self.v_proj(hidden) for hidden in cache_hidden]

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(query_states, seq_len=q_len + lck)
        cos, sin = cos.to(query_states.device), sin.to(query_states.device)
        # query_states = apply_rotary_pos_emb(query_states, cos, sin, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids + lck)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # Avoid modifying the hidden cache in place, which causes an in-place modification
        # error when gradient checkpointing is enabled. Return the updated hidden cache instead.
        if cache_hidden is None:
            local_cache_k = []
            local_cache_v = []
        else:
            local_cache_k = list(cache_hidden[0])
            local_cache_v = list(cache_hidden[1])

        local_cache_k.append(key_states)
        local_cache_v.append(value_states)

        cache_k = local_cache_k
        cache_v = local_cache_v

        k0 = cache_k[0]
        v0 = cache_v[0]
        lck = len(cache_k)

        attn_weights = torch.matmul(query_states, k0.transpose(2, 3)) / math.sqrt(self.head_dim)
        attn_weights = attn_weights + attention_mask

        for i in range(1, lck):
            ki = cache_k[i]

            qi = query_states
            kiq = ki

            attn_weightsi = (qi * kiq).sum(-1) / math.sqrt(self.head_dim)
            attn_weights = torch.cat((attn_weights, attn_weightsi[..., None]), dim=-1)

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights0 = attn_weights[..., :q_len]

        attn_output = torch.matmul(attn_weights0, v0)

        for i in range(1, lck):
            vi = cache_v[i]
            attn_weightsi = attn_weights[..., q_len + i - 1]
            attn_outputi = attn_weightsi[..., None] * vi
            attn_output = attn_output + attn_outputi

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        # Return the updated hidden cache.
        new_past_key_value = [local_cache_k, local_cache_v]
        return attn_output, new_past_key_value


class LlamaMLP(nn.Module):
    def __init__(self, config, last=True):
        super().__init__()
        self.last = last
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        # if last:
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        # else:
        #     self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size * 2, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        if self.config.pretraining_tp > 1:
            slice = self.intermediate_size // self.config.pretraining_tp
            gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
            up_proj_slices = self.up_proj.weight.split(slice, dim=0)
            down_proj_slices = self.down_proj.weight.split(slice, dim=1)

            gate_proj = torch.cat(
                [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
            )
            up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)

            intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
            down_proj = [
                F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
            ]
            down_proj = sum(down_proj)
        else:
            down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

        return down_proj


class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


class LlamaDecoderLayeremb(nn.Module):
    def __init__(self, config, last=True):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LlamaAttention(config=config)
        self.mlp = LlamaMLP(config, last=last)
        self.last = last
        # self.fc = nn.Linear(config.hidden_size * 2, config.hidden_size)
        self.hidden_norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # if self.index!=0:

        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_emb: torch.Tensor,
        hidden_states: torch.Tensor,
        cache_hidden: List[List[torch.Tensor]] = [],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """

        residual = hidden_states

        hidden_states = self.hidden_norm(hidden_states)
        input_emb = self.input_layernorm(input_emb)

        hidden_states = torch.cat((input_emb, hidden_states), dim=-1)

        return_hidden = hidden_states

        # cache_hidden.append(hidden_states)

        # Self Attention
        hidden_states, latest_hidden_cache = self.self_attn(
            cache_hidden=cache_hidden,
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states

        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states, return_hidden)

        return outputs, latest_hidden_cache


@torch.no_grad()
def padding(tensor, left=True):
    zeropadding = torch.zeros_like(tensor[:, -1:])
    if left:
        tensor = torch.cat((zeropadding, tensor[:, :-1]), dim=1)
    else:
        tensor = torch.cat((tensor[:, 1:], zeropadding), dim=1)
    return tensor


def process_data(data_chunk):
    token_dict = Counter()
    input_ids = data_chunk["input_ids"]
    loss_mask = data_chunk["loss_mask"]
    for i in range(len(input_ids)):
        ids = input_ids[i][0]
        mask = loss_mask[i][0]
        for j in range(len(ids)):
            if mask[j] == 1:
                token_dict[ids[j]] += 1

    return token_dict


def merge_dicts(dicts):
    """Merge multiple Counter dicts."""
    result = Counter()
    for d in dicts:
        result.update(d)
    return result


class Model(nn.Module):
    def __init__(self, config, ds_config, training_config, load_head=False, load_emb=True, path=None, model_type='llama'):
        super().__init__()
        self.model_type = model_type
        # self.layers = nn.ModuleList(
        #     [LlamaDecoderLayer(config, index=index) for index in range(config.num_hidden_layers)])
        self.train_config = training_config
        # Setting dschf to allow efficient ZeRO-3 usage between hf and ds.
        if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3:
            dschf = HfDeepSpeedConfig(ds_config)
        else:
            dschf = None
        self.midlayer = LlamaDecoderLayeremb(config)
        self.gradient_checkpointing = self.train_config["gradient_checkpointing"]
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.hidden_size = config.hidden_size
        self.draft_vocab_size = config.draft_vocab_size
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.length = 6  # Modified by ablation script

        # Load target model based on model_type
        if self.model_type == 'qwen3':
            self.target_model = Qwen3ForCausalLM.from_pretrained(path, torch_dtype=torch.float16)
        else:  # default to llama
            self.target_model = LlamaForCausalLM.from_pretrained(path, torch_dtype=torch.float16)

        self.target_model.eval()
        self.fc = nn.Linear(self.hidden_size * 3, self.hidden_size, bias=False)
        for param in self.target_model.parameters():
            param.requires_grad = False

        if not load_emb:
            self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        else:
            from safetensors import safe_open
            import json
            import os
            try:
                with open(os.path.join(path, "model.safetensors.index.json"), "r") as f:
                    index_json = json.loads(f.read())
                    emb_path = index_json["weight_map"]["model.embed_tokens.weight"]
                with safe_open(os.path.join(path, emb_path),
                               framework="pt",
                               device="cpu") as f:
                    tensor_slice = f.get_slice("model.embed_tokens.weight")
                    vocab_size, hidden_dim = tensor_slice.get_shape()
                    tensor = tensor_slice[:, :hidden_dim].float()
            except:
                with open(os.path.join(path, "pytorch_model.bin.index.json"), "r") as f:
                    index_json = json.loads(f.read())
                    emb_path = index_json["weight_map"]["model.embed_tokens.weight"]
                weights = torch.load(os.path.join(path, emb_path))
                tensor = weights["model.embed_tokens.weight"].float()
            self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, _weight=tensor)

        self.lm_head = nn.Linear(config.hidden_size, config.draft_vocab_size, bias=False)

        for param in self.embed_tokens.parameters():
            param.requires_grad = False

    def scandata(self, datapath, tokenizerpath):
        N = self.draft_vocab_size

        # [MODIFIED] Use different cache files for different model types
        cache_file = f"cache_{self.model_type}.pt" if self.model_type != 'llama' else "cache.pt"

        if not os.path.exists(cache_file):
            tokenizer = AutoTokenizer.from_pretrained(tokenizerpath)
            dataset = load_dataset('json', data_files=datapath)
            dataset = dataset['train']
            # dataset = dataset.select(range(96))
            original_columns1 = dataset.column_names
            num_proc = 48

            # [MODIFIED] Set separators based on model type
            if self.model_type == 'qwen3':
                sep = "<|im_end|>\n<|im_start|>assistant\n"
                sep2 = "<|im_end|>\n<|im_start|>user\n"
            else:  # llama
                sep = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
                sep2 = "<|eot_id|><|start_header_id|>user<|end_header_id|>"

            def preprocess_function(examples):
                new_examples = {
                    # "conversation": [],
                    "input_ids": [],
                    "loss_mask": []
                }
                for i in range(len(examples['id'])):
                    messages = [
                        {"role": "system",
                         "content": "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."},
                    ]
                    convroles = ["user", "assistant"]
                    roles = {"human": "user", "gpt": "assistant"}
                    source = examples['conversations'][i]
                    if not source:
                        continue
                    if roles[source[0]["from"]] != "user":
                        # Skip the first one if it is not from human
                        source = source[1:]
                    for j, sentence in enumerate(source):
                        role = roles[sentence["from"]]
                        assert role == convroles[j % 2], f"{i}"
                        # if sentence["from"]=="gpt":
                        #     sentence["value"]=" "+sentence["value"]
                        messages.append(
                            {"role": role, "content": sentence["value"]}
                        )
                    conversation = tokenizer.apply_chat_template(
                        messages,
                        tokenize=False,
                        add_generation_prompt=False,
                    )

                    if not tokenizer.pad_token_id:
                        tokenizer.pad_token_id = tokenizer.unk_token_id

                    input_ids = tokenizer(
                        conversation,
                        return_tensors="pt",
                        add_special_tokens=False,
                    ).input_ids[0]
                    # When constructing the draft-model vocab, filter out samples
                    # that are longer than max_len instead of truncating them.
                    if len(input_ids) > self.train_config["max_len"]:
                        continue
                    loss_mask = torch.ones_like(input_ids)
                    # print(i)

                    total_len = len(input_ids)

                    turns = conversation.split(sep2)

                    # [MODIFIED] Skip samples with invalid conversation structure
                    if len(turns) < 2:
                        continue

                    turns[1] = turns[0] + sep2 + turns[1]
                    turns = turns[1:]

                    cur_len = 1
                    loss_mask[:cur_len] = 0
                    for i, turn in enumerate(turns):
                        if turn == "":
                            break
                        turn_len = len(tokenizer(turn).input_ids)

                        parts = turn.split(sep)
                        if len(parts) != 2:
                            break
                        parts[0] += sep
                        # "-2" is hardcoded for the Llama tokenizer to make the offset correct.
                        instruction_len = len(tokenizer(parts[0]).input_ids) - 1

                        # Ignore the user instructions
                        if i == 0:
                            loss_mask[cur_len: cur_len + instruction_len - 2] = 0
                        else:
                            loss_mask[cur_len - 3: cur_len + instruction_len + 1] = 0
                        cur_len += turn_len
                        if i != 0:
                            cur_len += 3
                        # cur_len+=2

                        # if i != 0 and not tokenizer.legacy:
                        #     # The legacy and non-legacy modes handle special tokens differently
                        #     cur_len -= 1

                    loss_mask[cur_len:] = 0

                    # new_examples["conversation"].append(conversation)
                    new_examples["input_ids"].append(input_ids[None, :])
                    new_examples["loss_mask"].append(loss_mask[None, :])

                return new_examples

            dataset = dataset.map(
                preprocess_function,
                batched=True,
                num_proc=num_proc,
                remove_columns=original_columns1,
                load_from_cache_file=False
            )
            # dataset.set_format(type="torch")

            num_processes = num_proc
            chunk_size = len(dataset) // num_processes + (len(dataset) % num_processes > 0)
            chunks = [dataset[i:i + chunk_size] for i in range(0, len(dataset), chunk_size)]

            # Create the worker pool
            with multiprocessing.Pool(num_processes) as pool:
                # Count masked target tokens over the data chunks in parallel
                results = pool.map(process_data, chunks)

            # Merge the per-chunk counts
            token_dict = merge_dicts(results)

            total_frequency = sum(token_dict.values())
            top_N = token_dict.most_common(N)
            top_N_frequency_sum = sum(freq for key, freq in top_N)
            top_N_ratio = top_N_frequency_sum / total_frequency
            print(f"top {N} token frequency ratio: {top_N_ratio:.2%}")
            used_tokens = [key for key, freq in top_N]
            used_tokens.sort()
            # d2t[i] is the offset from draft-vocab id i to its target-vocab id;
            # t2d marks which target-vocab ids are kept in the draft vocab.
            d2t = [used_tokens[i] - i for i in range(len(used_tokens))]
            t2d = [i in used_tokens for i in range(self.vocab_size)]
            d2t = torch.tensor(d2t)
            t2d = torch.tensor(t2d)
            cache = {
                "d2t": d2t,
                "t2d": t2d
            }
            torch.save(cache, cache_file)
        else:
            cache = torch.load(cache_file)
            d2t = cache["d2t"]
            t2d = cache["t2d"]
        self.register_buffer("d2t", d2t)
        self.register_buffer("t2d", t2d)
        self.l1smooth = nn.SmoothL1Loss(reduction="none")

    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
                inputs_embeds.device
            )
            combined_attention_mask = (
                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    @torch.no_grad()
    def dataprepare(self, input_ids, attention_mask, loss_mask):
        device = input_ids.device
        outs = self.target_model(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states0 = outs.hidden_states[0]
        hidden_states1 = outs.hidden_states[1]
        hidden_states2 = outs.hidden_states[2]
        hidden_states = torch.cat((hidden_states0, hidden_states1, hidden_states2), dim=-1)
        # hidden_states = torch.cat((hidden_states0, hidden_states1), dim=-1)
        target = outs.logits
        target = padding(target, left=False)
        input_ids = padding(input_ids, left=False)

        if target is not None:
            target = target.to(device)
            loss_mask = loss_mask[..., None]
            loss_mask = loss_mask.to(device)

        return hidden_states, target, loss_mask, input_ids

    def forward(
        self,
        # hidden_states,
        input_ids,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        loss_mask: Optional[torch.Tensor] = None,
    ):
        hidden_states, target, loss_mask, input_ids = self.dataprepare(input_ids, attention_mask, loss_mask)

        batch_size, seq_length, _ = hidden_states.shape
        seq_length_with_past = seq_length
        past_key_values_length = 0

        # with torch.no_grad():
        #     inputs_embeds = self.embed_tokens(input_ids)
        #     inputs_embeds = inputs_embeds.detach()

        if self.training and self.gradient_checkpointing and not hidden_states.requires_grad:
            hidden_states.requires_grad = True

        hidden_states = self.fc(hidden_states)

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length
        if position_ids is None:
            device = hidden_states.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
        )

        if self.gradient_checkpointing and self.training:
            if use_cache:
                use_cache = False

        plosses = []
        vlosses = []
        acces = []
        cache_hidden = [[], []]

        # Unroll the draft decoder for `self.length` steps; inputs and targets are
        # shifted one position to the left after every step except the last.
        for idx in range(self.length):
            last = idx == self.length - 1
            inputs_embeds = self.embed_tokens(input_ids)
            if self.training and self.gradient_checkpointing and not inputs_embeds.requires_grad:
                inputs_embeds.requires_grad = True
            inputs_embeds = inputs_embeds.to(hidden_states.dtype)

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, None, output_attentions)

                    return custom_forward

                layer_outputs, cache_hidden = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.midlayer),
                    inputs_embeds,
                    hidden_states,
                    cache_hidden,
                    attention_mask,
                    position_ids,
                )
            else:
                layer_outputs, cache_hidden = self.midlayer(
                    input_emb=inputs_embeds,
                    hidden_states=hidden_states,
                    cache_hidden=cache_hidden,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=None,
                    output_attentions=output_attentions,
                    use_cache=True,
                )

            hidden_states_out = layer_outputs[0]
            # cache_hidden.append(layer_outputs[1])
            # kv_cache = layer_outputs[-1]

            with torch.no_grad():
                # hidden_states_target = padding(hidden_states, left=False)
                target_head = target
                target_max_token = target_head.argmax(-1)
                # Move t2d to the same device as target_max_token
                self.t2d = self.t2d.to(target_max_token.device)
                target_mask = self.t2d[target_max_token]
                target_mask = target_mask[..., None].int()
                position_mask = target_mask * loss_mask
                target_head = target_head[..., self.t2d]
                target_head = target_head.float()
                target_p = nn.Softmax(dim=2)(target_head)
                target_p = target_p.detach()

            hidden_states = hidden_states_out

            hidden_states_out = self.norm(hidden_states_out)

            logits = self.lm_head(hidden_states_out)
            logits = logits.float()
            out_logp = nn.LogSoftmax(dim=2)(logits)
            plogp = target_p * out_logp
            loss = -torch.sum(position_mask * plogp, 2).mean()
            plosses.append(loss)
            with torch.no_grad():
                acces.append(((logits.argmax(-1) == target_p.argmax(-1)) * position_mask.squeeze(-1)).sum().item() / (
                        loss_mask.sum().item() + 1e-6))

            if not last:
                input_ids = padding(input_ids, left=False)
                target = padding(target, left=False)
                loss_mask = padding(loss_mask, left=False)

        return plosses, vlosses, acces
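For orientation, a minimal driver sketch follows (it is not part of the uploaded file). It assumes `EConfig` follows the usual `PretrainedConfig.from_pretrained` interface, that a target-model checkpoint and a ShareGPT-style JSON file exist at the placeholder paths, and that a CUDA device is available; all names and paths below are illustrative.

# Hypothetical usage sketch -- illustrative only, not part of cnets.py.
import torch
from configs import EConfig          # assumed to behave like a standard HF config
from cnets import Model

config = EConfig.from_pretrained("config.json")                        # placeholder config path
training_config = {"gradient_checkpointing": False, "max_len": 2048}   # keys read by Model

model = Model(config, ds_config=None, training_config=training_config,
              path="/path/to/target-model", model_type="llama")
model.scandata("/path/to/sharegpt.jsonl", "/path/to/target-model")     # builds the d2t/t2d buffers
model = model.to("cuda")             # target model is loaded in float16, so run on GPU

# Smoke-test batch; real training would draw these from the preprocessed dataset.
dummy_ids = torch.randint(0, config.vocab_size, (1, 32), device="cuda")
plosses, vlosses, acces = model(
    input_ids=dummy_ids,
    attention_mask=torch.ones_like(dummy_ids),
    loss_mask=torch.ones_like(dummy_ids),
)
loss = sum(plosses) / len(plosses)   # average the per-step draft losses
loss.backward()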