# Residue from the Hugging Face file page this module was copied from (kept for provenance):
# appleeji's picture
# Upload cnets.py with huggingface_hub
# 5c1c6bc verified
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch LLaMA model."""
import math
from typing import List, Optional, Tuple, Union
from collections import Counter
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
import os
from transformers.integrations.deepspeed import HfDeepSpeedConfig
from transformers.activations import ACT2FN
from transformers import AutoTokenizer
from modeling_llama_kv import LlamaForCausalLM
from modeling_qwen_kv import Qwen3ForCausalLM
from configs import EConfig
from safetensors import safe_open
from datasets import load_dataset
import multiprocessing
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head ``n_rep`` times along dim 1 (grouped-query attention).

    Same result as ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``:
    ``(batch, num_key_value_heads, seqlen, head_dim)`` ->
    ``(batch, num_key_value_heads * n_rep, seqlen, head_dim)``.
    The input must be 4-D; when ``n_rep == 1`` the input object itself is returned.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    widened = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return widened.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
def rotate_half(x):
    """Map the last dimension ``[a, b] -> [-b, a]``, where ``a`` is the first ``d // 2`` entries."""
    half = x.shape[-1] // 2
    leading = x[..., :half]
    trailing = x[..., half:]
    return torch.cat((-trailing, leading), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """
    Apply rotary position embeddings to ``q`` and ``k``.

    ``cos``/``sin`` are squeezed at dim 1 then dim 0 (both expected to be size 1, as in the
    ``[1, 1, seq_len, dim]`` cache slices), gathered at ``position_ids`` and unsqueezed at dim 1
    so they broadcast over heads. Each output is ``x * cos + rotate_half(x) * sin``.

    Returns:
        ``(q_embed, k_embed)``.
    """
    def _gather(table):
        # [1, 1, seq_len, dim] -> [seq_len, dim] -> [bs, seq_len, dim] -> [bs, 1, seq_len, dim]
        return table.squeeze(1).squeeze(0)[position_ids].unsqueeze(1)

    cos_at = _gather(cos)
    sin_at = _gather(sin)
    return tuple(t * cos_at + rotate_half(t) * sin_at for t in (q, k))
class LlamaRotaryEmbedding(torch.nn.Module):
    """
    Rotary position embedding backed by a lazily-grown cos/sin cache.

    The cache covers ``max_position_embeddings`` positions at construction time and is rebuilt
    (with the input's device and dtype) whenever ``forward`` is asked for more positions.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build ``cos_cached``/``sin_cached`` with shape ``[1, 1, seq_len, dim]`` (for even ``dim``)."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """
        Return ``(cos, sin)`` for the first ``seq_len`` positions, cast to ``x.dtype``.

        Args:
            x: tensor providing dtype/device; presumably ``[bs, num_attention_heads, seq_len, head_size]``.
            seq_len: number of positions to return. Defaults to ``x.shape[-2]`` when ``None``
                (the default previously raised ``TypeError``).
        """
        if seq_len is None:
            seq_len = x.shape[-2]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """
    LlamaRotaryEmbedding with linear position interpolation: positions are divided by
    ``scaling_factor`` before the rotary angles are computed. Credits to the Reddit user /u/kaiokendev
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # Set before the parent constructor, which builds the cache through the override below.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        scaled_positions = (
            torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
            / self.scaling_factor
        )
        angles = torch.einsum("i,j->ij", scaled_positions, self.inv_freq)
        # Angles are duplicated (not interleaved as in the paper); a permutation that yields the same result.
        doubled = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", doubled.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", doubled.sin()[None, None, :, :].to(dtype), persistent=False)
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """
    LlamaRotaryEmbedding with dynamic NTK scaling: when the cache is built for more than
    ``max_position_embeddings`` positions, the rotary base is enlarged and ``inv_freq`` recomputed.
    Credits to the Reddit users /u/bloc97 and /u/emozilla
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # Set before the parent constructor, which builds the cache through the override below.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        if seq_len > self.max_position_embeddings:
            growth = (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            scaled_base = self.base * growth ** (self.dim / (self.dim - 2))
            exponents = torch.arange(0, self.dim, 2).float().to(device) / self.dim
            self.register_buffer("inv_freq", 1.0 / (scaled_base ** exponents), persistent=False)
        positions = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        angles = torch.einsum("i,j->ij", positions, self.inv_freq)
        # Angles are duplicated (not interleaved as in the paper); a permutation that yields the same result.
        doubled = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", doubled.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", doubled.sin()[None, None, :, :].to(dtype), persistent=False)
class LlamaAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need', adapted for multi-step draft training.

    Q/K/V are projected from a ``2 * hidden_size`` input. Keys/values of earlier draft steps arrive
    through ``cache_hidden``: the first cached entry is attended with ``attention_mask``, while every
    later entry contributes a single score per query position (against the key at that same position).
    The extended cache is returned instead of being modified in place.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        # Inputs are the concatenation of two hidden_size-wide tensors, hence `* 2`.
        self.q_proj = nn.Linear(self.hidden_size * 2, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size * 2, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size * 2, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self._init_rope()

    def _init_rope(self):
        """Create ``self.rotary_emb`` according to ``config.rope_scaling`` (none / linear / dynamic)."""
        if self.config.rope_scaling is None:
            self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                    self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                    self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        cache_hidden: Optional[List[List[torch.Tensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, List[List[torch.Tensor]]]:
        """
        Args:
            hidden_states: ``(bsz, q_len, 2 * hidden_size)`` input (width required by the projections).
            cache_hidden: ``[keys, values]`` lists from earlier draft steps, or ``None`` for no history.
            attention_mask: additive mask added to the scores against the first key entry; presumably
                ``(bsz, 1, q_len, q_len)`` — TODO confirm. Skipped when ``None``.
            position_ids: rotary positions; defaults to ``arange(q_len)``. They are offset by the number
                of cached steps before the rotary embedding is applied.
            past_key_value, output_attentions, use_cache: accepted for interface compatibility; unused.

        Returns:
            ``(attn_output, [keys, values])``: output of shape ``(bsz, q_len, hidden_size)`` and copies of
            the cache lists extended with this step's (head-repeated) keys and values.
        """
        bsz, q_len, _ = hidden_states.size()

        # Copy the incoming cache instead of appending to the caller's lists: in-place modification
        # breaks gradient checkpointing. The extended copy is returned.
        if cache_hidden is None:
            cache_k, cache_v = [], []
        else:
            cache_k, cache_v = list(cache_hidden[0]), list(cache_hidden[1])
        # Number of draft steps already cached; used as the rotary position offset for this step.
        num_cached = len(cache_k)

        if position_ids is None:
            position_ids = torch.arange(q_len, dtype=torch.long, device=hidden_states.device).unsqueeze(0)

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(query_states, seq_len=q_len + num_cached)
        cos, sin = cos.to(query_states.device), sin.to(query_states.device)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin, position_ids + num_cached
        )

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        cache_k.append(key_states)
        cache_v.append(value_states)

        scale = math.sqrt(self.head_dim)
        # Masked attention against the first step's keys.
        attn_weights = torch.matmul(query_states, cache_k[0].transpose(2, 3)) / scale
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask
        # Each later step appends one column: every query scored against the key at its own position.
        for step_keys in cache_k[1:]:
            step_scores = (query_states * step_keys).sum(-1) / scale
            attn_weights = torch.cat((attn_weights, step_scores[..., None]), dim=-1)

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

        attn_output = torch.matmul(attn_weights[..., :q_len], cache_v[0])
        for step, step_values in enumerate(cache_v[1:], start=1):
            step_weight = attn_weights[..., q_len + step - 1]
            attn_output = attn_output + step_weight[..., None] * step_values

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output, [cache_k, cache_v]
class LlamaMLP(nn.Module):
    """
    Gated feed-forward block: ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``.

    When ``config.pretraining_tp > 1`` the projections are evaluated in ``pretraining_tp`` weight
    slices along the intermediate dimension and the partial ``down_proj`` outputs are summed.
    """

    def __init__(self, config, last=True):
        super().__init__()
        # Stored for compatibility; not used by this module (down_proj always maps back to hidden_size).
        self.last = last
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Apply the MLP to ``x`` of shape ``(..., hidden_size)``; returns ``(..., hidden_size)``."""
        tp = self.config.pretraining_tp
        if tp <= 1:
            return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

        # `slice_size` rather than `slice`, which would shadow the builtin.
        slice_size = self.intermediate_size // tp
        gate_slices = self.gate_proj.weight.split(slice_size, dim=0)
        up_slices = self.up_proj.weight.split(slice_size, dim=0)
        down_slices = self.down_proj.weight.split(slice_size, dim=1)

        gate = torch.cat([F.linear(x, gate_slices[i]) for i in range(tp)], dim=-1)
        up = torch.cat([F.linear(x, up_slices[i]) for i in range(tp)], dim=-1)
        # Split along the last (intermediate) dim so inputs of any rank work, not only 3-D ones.
        intermediate = (self.act_fn(gate) * up).split(slice_size, dim=-1)
        return sum(F.linear(intermediate[i], down_slices[i]) for i in range(tp))
class LlamaRMSNorm(nn.Module):
    """
    Root-mean-square layer norm with a learned per-channel scale (equivalent to T5LayerNorm).

    Statistics are computed in float32; the normalized values are cast back to the input dtype
    before being multiplied by ``weight``.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        original_dtype = hidden_states.dtype
        as_fp32 = hidden_states.to(torch.float32)
        mean_square = as_fp32.pow(2).mean(-1, keepdim=True)
        normalized = as_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)
class LlamaDecoderLayeremb(nn.Module):
    """
    Draft decoder layer whose attention input is ``concat(norm(input_emb), norm(hidden_states))``.

    The residual stream is ``hidden_states`` (width ``hidden_size``). The concatenated
    ``2 * hidden_size`` tensor feeds the attention projections and is also returned to the caller.
    """

    def __init__(self, config, last=True):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LlamaAttention(config=config)
        self.mlp = LlamaMLP(config, last=last)
        self.last = last
        self.hidden_norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_emb: torch.Tensor,
        hidden_states: torch.Tensor,
        cache_hidden: Optional[List[List[torch.Tensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], List[List[torch.Tensor]]]:
        """
        Args:
            input_emb: token embeddings, presumably ``(batch, seq_len, hidden_size)`` — TODO confirm.
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            cache_hidden: ``[keys, values]`` from earlier draft steps. ``None`` (the default) means no
                history and becomes a fresh ``[[], []]`` on every call. The previous mutable ``[]``
                default crashed inside attention.
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            position_ids: forwarded to attention.
            past_key_value, output_attentions, use_cache: forwarded to attention unchanged.

        Returns:
            ``((hidden_states, attn_input), cache)``: the layer output, the concatenated normalized
            tensor fed to attention, and the updated ``[keys, values]`` returned by attention.
        """
        if cache_hidden is None:
            cache_hidden = [[], []]
        residual = hidden_states
        hidden_states = self.hidden_norm(hidden_states)
        input_emb = self.input_layernorm(input_emb)
        hidden_states = torch.cat((input_emb, hidden_states), dim=-1)
        return_hidden = hidden_states

        # Self Attention
        hidden_states, latest_hidden_cache = self.self_attn(
            cache_hidden=cache_hidden,
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # Fully connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states, return_hidden)
        return outputs, latest_hidden_cache
@torch.no_grad()
def padding(tensor, left=True):
    """
    Shift ``tensor`` by one step along dim 1 and zero-fill the vacated slot (no autograd).

    ``left=True``: prepend zeros and drop the last step. ``left=False``: drop the first step and
    append zeros. Shape and dtype are preserved.
    """
    filler = torch.zeros_like(tensor[:, -1:])
    if not left:
        return torch.cat((tensor[:, 1:], filler), dim=1)
    return torch.cat((filler, tensor[:, :-1]), dim=1)
def process_data(data_chunk):
    """
    Count token ids at supervised positions.

    ``data_chunk`` maps ``"input_ids"`` and ``"loss_mask"`` to per-row sequences; each row's
    values are read from its element ``[0]``. A token is counted where the mask value equals 1.

    Returns:
        ``collections.Counter`` mapping token id -> occurrence count.
    """
    counts = Counter()
    all_ids = data_chunk["input_ids"]
    all_masks = data_chunk["loss_mask"]
    for row, id_row in enumerate(all_ids):
        ids = id_row[0]
        mask = all_masks[row][0]
        for position, token in enumerate(ids):
            if mask[position] == 1:
                counts[token] += 1
    return counts
def merge_dicts(dicts):
    """
    Merge multiple Counter dictionaries by summing counts per key.

    Uses ``Counter.update``, so zero and negative totals are kept (unlike ``Counter.__add__``).
    """
    merged = Counter()
    for partial in dicts:
        merged.update(partial)
    return merged
class Model(nn.Module):
    """
    Draft model trained against a frozen target model (Llama or Qwen3).

    ``forward`` runs the target model to get three concatenated hidden states and its logits, then
    unrolls the single draft layer ``self.length`` times. Each step shifts inputs/targets one position
    left and records a soft cross-entropy loss and a top-1 accuracy. The draft head predicts over a
    reduced vocabulary of ``config.draft_vocab_size`` ids chosen by ``scandata``.
    """

    def __init__(self, config, ds_config, training_config, load_head=False, load_emb=True, path=None, model_type='llama'):
        super().__init__()
        self.model_type = model_type
        self.train_config = training_config
        # Setting dschf to allow efficient ZeRO-3 usage between hf and ds. The transformers DeepSpeed
        # docs require this object to stay alive while the model is created (from_pretrained below).
        if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3:
            dschf = HfDeepSpeedConfig(ds_config)
        else:
            dschf = None
        self.midlayer = LlamaDecoderLayeremb(config)
        self.gradient_checkpointing = self.train_config["gradient_checkpointing"]
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.hidden_size = config.hidden_size
        self.draft_vocab_size = config.draft_vocab_size
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.length = 6 # Modified by ablation script
        # Load the frozen target model based on model_type (default: llama).
        if self.model_type == 'qwen3':
            self.target_model = Qwen3ForCausalLM.from_pretrained(path, torch_dtype=torch.float16)
        else:
            self.target_model = LlamaForCausalLM.from_pretrained(path, torch_dtype=torch.float16)
        self.target_model.eval()
        # Projects the three concatenated target hidden states (see dataprepare) back to hidden_size.
        self.fc = nn.Linear(self.hidden_size * 3, self.hidden_size, bias=False)
        for param in self.target_model.parameters():
            param.requires_grad = False
        if not load_emb:
            self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        else:
            tensor = self._load_embedding_weight(path)
            self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, _weight=tensor)
        self.lm_head = nn.Linear(config.hidden_size, config.draft_vocab_size, bias=False)
        for param in self.embed_tokens.parameters():
            param.requires_grad = False

    @staticmethod
    def _load_embedding_weight(path):
        """
        Read ``model.embed_tokens.weight`` from checkpoint directory ``path`` as a float32 CPU tensor.

        Tries, in order: the sharded safetensors index, a single ``model.safetensors`` file, and
        finally the sharded PyTorch index (``pytorch_model.bin.index.json``).
        """
        import json

        key = "model.embed_tokens.weight"
        safetensors_index = os.path.join(path, "model.safetensors.index.json")
        try:
            if os.path.exists(safetensors_index):
                with open(safetensors_index, "r") as f:
                    emb_file = os.path.join(path, json.load(f)["weight_map"][key])
            else:
                emb_file = os.path.join(path, "model.safetensors")
            with safe_open(emb_file, framework="pt", device="cpu") as f:
                return f.get_tensor(key).float()
        except Exception:
            # No usable safetensors checkpoint; fall through to the PyTorch .bin index below.
            # (Narrower than a bare `except:` so KeyboardInterrupt/SystemExit still propagate.)
            pass
        with open(os.path.join(path, "pytorch_model.bin.index.json"), "r") as f:
            emb_file = json.load(f)["weight_map"][key]
        # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
        weights = torch.load(os.path.join(path, emb_file), map_location="cpu")
        return weights[key].float()

    def scandata(self, datapath, tokenizerpath):
        """
        Build (or load from cache) the draft-vocabulary maps and register them as buffers.

        Selects the ``draft_vocab_size`` most frequent token ids at supervised positions of the JSON
        dataset at ``datapath``. ``d2t``/``t2d`` are cached in the current working directory as
        ``cache.pt`` (llama) or ``cache_<model_type>.pt`` and reused when that file exists.
        Also creates ``self.l1smooth``.

        Raises:
            ValueError: if the dataset yields no supervised tokens.
        """
        N = self.draft_vocab_size
        # Use different cache files for different model types
        cache_file = f"cache_{self.model_type}.pt" if self.model_type != 'llama' else "cache.pt"
        if not os.path.exists(cache_file):
            token_dict = self._count_supervised_tokens(datapath, tokenizerpath)
            d2t, t2d = self._build_vocab_maps(token_dict, N, self.vocab_size)
            cache = {
                "d2t": d2t,
                "t2d": t2d
            }
            torch.save(cache, cache_file)
        else:
            cache = torch.load(cache_file)
            d2t = cache["d2t"]
            t2d = cache["t2d"]
        self.register_buffer("d2t", d2t)
        self.register_buffer("t2d", t2d)
        self.l1smooth = nn.SmoothL1Loss(reduction="none")

    def _count_supervised_tokens(self, datapath, tokenizerpath):
        """Tokenize the conversation dataset at ``datapath``; return a Counter of ids at loss-masked positions."""
        tokenizer = AutoTokenizer.from_pretrained(tokenizerpath)
        # `is None` rather than falsiness: 0 is a valid pad token id.
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.unk_token_id
        dataset = load_dataset('json', data_files=datapath)
        dataset = dataset['train']
        original_columns1 = dataset.column_names
        num_proc = 1  # Changed from 48 to avoid DeepSpeed pickle issues
        # Set separators based on model type
        if self.model_type == 'qwen3':
            sep = "<|im_end|>\n<|im_start|>assistant\n"
            sep2 = "<|im_end|>\n<|im_start|>user\n"
        else:  # llama
            sep = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
            sep2 = "<|eot_id|><|start_header_id|>user<|end_header_id|>"

        def preprocess_function(examples):
            new_examples = {
                "input_ids": [],
                "loss_mask": []
            }
            for i in range(len(examples['id'])):
                messages = [
                    {"role": "system",
                     "content": "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."},
                ]
                convroles = ["user", "assistant"]
                roles = {"human": "user", "gpt": "assistant"}
                source = examples['conversations'][i]
                if not source:
                    continue
                if roles[source[0]["from"]] != "user":
                    # Skip the first one if it is not from human
                    source = source[1:]
                for j, sentence in enumerate(source):
                    role = roles[sentence["from"]]
                    assert role == convroles[j % 2], f"{i}"
                    messages.append(
                        {"role": role, "content": sentence["value"]}
                    )
                conversation = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=False,
                )
                input_ids = tokenizer(
                    conversation,
                    return_tensors="pt",
                    add_special_tokens=False,
                ).input_ids[0]
                # When constructing the draft model vocab, filter out samples longer than max_len
                # instead of truncating them.
                if len(input_ids) > self.train_config["max_len"]:
                    continue
                loss_mask = torch.ones_like(input_ids)
                turns = conversation.split(sep2)
                # Skip samples with invalid conversation structure
                if len(turns) < 2:
                    continue
                turns[1] = turns[0] + sep2 + turns[1]
                turns = turns[1:]
                cur_len = 1
                loss_mask[:cur_len] = 0
                # `turn_idx` (not `i`) so the outer example index is not shadowed.
                for turn_idx, turn in enumerate(turns):
                    if turn == "":
                        break
                    turn_len = len(tokenizer(turn).input_ids)
                    parts = turn.split(sep)
                    if len(parts) != 2:
                        break
                    parts[0] += sep
                    # The offsets below are hard-coded for the chat templates above
                    # ("-2" was tuned for the Llama tokenizer to make the offset correct).
                    instruction_len = len(tokenizer(parts[0]).input_ids) - 1
                    # Ignore the user instructions
                    if turn_idx == 0:
                        loss_mask[cur_len: cur_len + instruction_len - 2] = 0
                    else:
                        loss_mask[cur_len - 3: cur_len + instruction_len + 1] = 0
                    cur_len += turn_len
                    if turn_idx != 0:
                        cur_len += 3
                loss_mask[cur_len:] = 0
                new_examples["input_ids"].append(input_ids[None, :])
                new_examples["loss_mask"].append(loss_mask[None, :])
            return new_examples

        dataset = dataset.map(
            preprocess_function,
            batched=True,
            num_proc=num_proc,
            remove_columns=original_columns1,
            load_from_cache_file=False
        )
        # Count in-process in a single chunk to avoid DeepSpeed pickle issues
        # (multiprocessing.Pool cannot pickle torch.distributed ProcessGroup).
        # Empty datasets are handled explicitly (a zero slicing step used to raise ValueError).
        if len(dataset) == 0:
            return Counter()
        return process_data(dataset[0:len(dataset)])

    @staticmethod
    def _build_vocab_maps(token_counts, num_draft_tokens, vocab_size):
        """
        Pick the ``num_draft_tokens`` most frequent ids and build the draft/target vocabulary maps.

        Returns:
            ``(d2t, t2d)``. ``d2t[i] = used_tokens[i] - i`` (offset from draft id ``i`` to its target
            id, with ``used_tokens`` sorted ascending). ``t2d`` is a length-``vocab_size`` bool tensor
            that is True at selected target ids.

        Raises:
            ValueError: if ``token_counts`` holds no tokens.
        """
        total_frequency = sum(token_counts.values())
        if total_frequency == 0:
            raise ValueError("no supervised tokens found in the dataset; cannot build the draft vocabulary")
        top_N = token_counts.most_common(num_draft_tokens)
        top_N_frequency_sum = sum(freq for _, freq in top_N)
        top_N_ratio = top_N_frequency_sum / total_frequency
        print(f"top {num_draft_tokens} token frequency ratio: {top_N_ratio:.2%}")
        used_tokens = sorted(key for key, _ in top_N)
        d2t = torch.tensor([token - i for i, token in enumerate(used_tokens)])
        # Set membership keeps this O(vocab_size) instead of O(vocab_size * num_draft_tokens).
        used_set = set(used_tokens)
        t2d = torch.tensor([i in used_set for i in range(vocab_size)])
        return d2t, t2d

    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        """Combine a causal mask with the expanded padding mask into ``[bsz, 1, tgt_len, src_len]`` additive form."""
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
                inputs_embeds.device
            )
            combined_attention_mask = (
                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
            )
        return combined_attention_mask

    @torch.no_grad()
    def dataprepare(self, input_ids, attention_mask, loss_mask):
        """
        Run the frozen target model and build the draft's inputs and targets.

        Returns:
            ``(hidden_states, target, loss_mask, input_ids)``. ``hidden_states`` is the last-dim
            concatenation of ``outs.hidden_states[0..2]``. ``target`` is the target logits and
            ``input_ids`` the ids, both shifted one position left with zero fill. ``loss_mask`` gains
            a trailing singleton dim.
        """
        device = input_ids.device
        outs = self.target_model(input_ids=input_ids, attention_mask=attention_mask)
        # NOTE(review): which target layers these three entries are depends on the custom
        # modeling_*_kv model — confirm there.
        hidden_states0 = outs.hidden_states[0]
        hidden_states1 = outs.hidden_states[1]
        hidden_states2 = outs.hidden_states[2]
        hidden_states = torch.cat((hidden_states0, hidden_states1, hidden_states2), dim=-1)
        target = padding(outs.logits, left=False).to(device)
        input_ids = padding(input_ids, left=False)
        loss_mask = loss_mask[..., None].to(device)
        return hidden_states, target, loss_mask, input_ids

    def forward(
        self,
        input_ids,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        loss_mask: Optional[torch.Tensor] = None,
    ):
        """
        Unroll the draft layer ``self.length`` times against the frozen target model.

        Args:
            input_ids: token ids fed to the target model.
            attention_mask: padding mask passed to the target model and turned into the draft's
                additive causal mask; all-ones when ``None``.
            position_ids: optional positions; defaults to ``arange`` over the sequence.
            past_key_values: only used for a position/mask offset (``past_key_values[0][0].shape[2]``).
            loss_mask: mask of supervised positions; required (it is indexed in ``dataprepare``).
            use_cache, output_hidden_states: accepted for interface compatibility; unused.
            output_attentions: forwarded to the draft layer.

        Returns:
            ``(plosses, vlosses, acces)``: one soft cross-entropy loss tensor per step, an always-empty
            list, and one top-1 accuracy float per step. Accuracy is measured over positions whose target
            argmax is in the draft vocabulary and whose loss mask is set.
        """
        hidden_states, target, loss_mask, input_ids = self.dataprepare(input_ids, attention_mask, loss_mask)
        batch_size, seq_length, _ = hidden_states.shape
        seq_length_with_past = seq_length
        past_key_values_length = 0

        if self.training and self.gradient_checkpointing and not hidden_states.requires_grad:
            hidden_states.requires_grad = True
        hidden_states = self.fc(hidden_states)

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length
        if position_ids is None:
            device = hidden_states.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
        )

        def _run_midlayer_checkpointed(*inputs):
            # None for past_key_value
            return self.midlayer(*inputs, None, output_attentions)

        plosses = []
        vlosses = []
        acces = []
        cache_hidden = [[], []]
        # t2d is loop-invariant: move the buffer to the logits' device once.
        self.t2d = self.t2d.to(target.device)
        for idx in range(self.length):
            last = idx == self.length - 1
            inputs_embeds = self.embed_tokens(input_ids)
            if self.training and self.gradient_checkpointing and not inputs_embeds.requires_grad:
                inputs_embeds.requires_grad = True
            inputs_embeds = inputs_embeds.to(hidden_states.dtype)

            if self.gradient_checkpointing and self.training:
                # NOTE(review): tensors nested inside the cache_hidden lists may not be tracked by the
                # reentrant checkpoint implementation — confirm the intended `use_reentrant` setting.
                layer_outputs, cache_hidden = torch.utils.checkpoint.checkpoint(
                    _run_midlayer_checkpointed,
                    inputs_embeds,
                    hidden_states,
                    cache_hidden,
                    attention_mask,
                    position_ids,
                )
            else:
                layer_outputs, cache_hidden = self.midlayer(
                    input_emb=inputs_embeds,
                    hidden_states=hidden_states,
                    cache_hidden=cache_hidden,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=None,
                    output_attentions=output_attentions,
                    use_cache=True,
                )

            hidden_states_out = layer_outputs[0]

            with torch.no_grad():
                target_head = target
                target_max_token = target_head.argmax(-1)
                # Supervise only positions whose target argmax is in the draft vocabulary.
                target_mask = self.t2d[target_max_token]
                target_mask = target_mask[..., None].int()
                position_mask = target_mask * loss_mask
                # Restrict the target distribution to the draft vocabulary columns.
                target_head = target_head[..., self.t2d]
                target_head = target_head.float()
                target_p = nn.Softmax(dim=2)(target_head)
                target_p = target_p.detach()

            hidden_states = hidden_states_out
            hidden_states_out = self.norm(hidden_states_out)
            logits = self.lm_head(hidden_states_out)
            logits = logits.float()
            out_logp = nn.LogSoftmax(dim=2)(logits)
            plogp = target_p * out_logp
            loss = -torch.sum(position_mask * plogp, 2).mean()
            plosses.append(loss)

            with torch.no_grad():
                correct = (logits.argmax(-1) == target_p.argmax(-1)) * position_mask.squeeze(-1)
                acces.append(correct.sum().item() / (position_mask.sum().item() + 1e-6))

            if not last:
                input_ids = padding(input_ids, left=False)
                target = padding(target, left=False)
                loss_mask = padding(loss_mask, left=False)

        return plosses, vlosses, acces