""" PyTorch LLaMA model.""" |
|
|
import math |
|
|
from typing import List, Optional, Tuple, Union |
|
|
from collections import Counter |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import torch.utils.checkpoint |
|
|
from torch import nn |
|
|
import os |
|
|
from transformers.integrations.deepspeed import HfDeepSpeedConfig |
|
|
from transformers.activations import ACT2FN |
|
|
from transformers import AutoTokenizer |
|
|
from modeling_llama_kv import LlamaForCausalLM |
|
|
from modeling_qwen_kv import Qwen3ForCausalLM |
|
|
from configs import EConfig |
|
|
from safetensors import safe_open |
|
|
from datasets import load_dataset |
|
|
import multiprocessing |
|
|
|
|
|
|
|
|
def _make_causal_mask( |
|
|
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 |
|
|
): |
|
|
""" |
|
|
    Make causal mask used for causal (uni-directional) self-attention.
|
|
""" |
|
|
bsz, tgt_len = input_ids_shape |
|
|
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) |
|
|
mask_cond = torch.arange(mask.size(-1), device=device) |
|
|
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) |
|
|
mask = mask.to(dtype) |
|
|
|
|
|
if past_key_values_length > 0: |
|
|
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) |
|
|
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) |
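
# Illustrative note: for input_ids_shape = (bsz, 3) and past_key_values_length = 2,
# the returned mask has shape (bsz, 1, 3, 5); the two "past" columns are all
# zeros (always visible) and the new 3x3 block is lower-triangular, with
# blocked entries set to torch.finfo(dtype).min.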
|
|
|
|
|
|
|
|
|
|
|
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): |
|
|
""" |
|
|
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. |
|
|
""" |
|
|
bsz, src_len = mask.size() |
|
|
tgt_len = tgt_len if tgt_len is not None else src_len |
|
|
|
|
|
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) |
|
|
|
|
|
inverted_mask = 1.0 - expanded_mask |
|
|
|
|
|
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) |
|
|
|
|
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: |
|
|
""" |
|
|
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, |
|
|
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) |
|
|
""" |
|
|
batch, num_key_value_heads, slen, head_dim = hidden_states.shape |
|
|
if n_rep == 1: |
|
|
return hidden_states |
|
|
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) |
|
|
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) |
|
|
|
|
|
|
|
|
def rotate_half(x): |
|
|
"""Rotates half the hidden dims of the input.""" |
|
|
x1 = x[..., : x.shape[-1] // 2] |
|
|
x2 = x[..., x.shape[-1] // 2:] |
|
|
return torch.cat((-x2, x1), dim=-1) |
|
|
|
|
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids): |
|
|
|
|
|
cos = cos.squeeze(1).squeeze(0) |
|
|
sin = sin.squeeze(1).squeeze(0) |
|
|
cos = cos[position_ids].unsqueeze(1) |
|
|
sin = sin[position_ids].unsqueeze(1) |
|
|
q_embed = (q * cos) + (rotate_half(q) * sin) |
|
|
k_embed = (k * cos) + (rotate_half(k) * sin) |
|
|
return q_embed, k_embed |
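
# Shape note: `cos`/`sin` come from the rotary cache as (1, 1, seq_len, head_dim);
# after the squeeze/gather above they broadcast against q/k of shape
# (bsz, num_heads, q_len, head_dim). Callers here pass `position_ids + lck` so
# that positions are offset by the number of cached draft steps.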
|
|
|
|
|
|
|
|
class LlamaRotaryEmbedding(torch.nn.Module): |
|
|
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): |
|
|
super().__init__() |
|
|
|
|
|
self.dim = dim |
|
|
self.max_position_embeddings = max_position_embeddings |
|
|
self.base = base |
|
|
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) |
|
|
self.register_buffer("inv_freq", inv_freq, persistent=False) |
|
|
|
|
|
|
|
|
self._set_cos_sin_cache( |
|
|
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() |
|
|
) |
|
|
|
|
|
def _set_cos_sin_cache(self, seq_len, device, dtype): |
|
|
self.max_seq_len_cached = seq_len |
|
|
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) |
|
|
|
|
|
freqs = torch.einsum("i,j->ij", t, self.inv_freq) |
|
|
|
|
|
emb = torch.cat((freqs, freqs), dim=-1) |
|
|
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) |
|
|
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) |
|
|
|
|
|
def forward(self, x, seq_len=None): |
|
|
|
|
|
if seq_len > self.max_seq_len_cached: |
|
|
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) |
|
|
|
|
|
return ( |
|
|
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), |
|
|
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), |
|
|
) |
|
|
|
|
|
|
|
|
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): |
|
|
"""LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" |
|
|
|
|
|
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): |
|
|
self.scaling_factor = scaling_factor |
|
|
super().__init__(dim, max_position_embeddings, base, device) |
|
|
|
|
|
def _set_cos_sin_cache(self, seq_len, device, dtype): |
|
|
self.max_seq_len_cached = seq_len |
|
|
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) |
|
|
t = t / self.scaling_factor |
|
|
|
|
|
freqs = torch.einsum("i,j->ij", t, self.inv_freq) |
|
|
|
|
|
emb = torch.cat((freqs, freqs), dim=-1) |
|
|
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) |
|
|
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) |
|
|
|
|
|
|
|
|
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): |
|
|
"""LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" |
|
|
|
|
|
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): |
|
|
self.scaling_factor = scaling_factor |
|
|
super().__init__(dim, max_position_embeddings, base, device) |
|
|
|
|
|
def _set_cos_sin_cache(self, seq_len, device, dtype): |
|
|
self.max_seq_len_cached = seq_len |
|
|
|
|
|
if seq_len > self.max_position_embeddings: |
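            # NTK-aware scaling: enlarge the RoPE base as `seq_len` exceeds the
            # trained context, stretching the low-frequency bands to cover the
            # longer sequence while leaving the highest frequencies nearly
            # unchanged.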
|
|
base = self.base * ( |
|
|
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) |
|
|
) ** (self.dim / (self.dim - 2)) |
|
|
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) |
|
|
self.register_buffer("inv_freq", inv_freq, persistent=False) |
|
|
|
|
|
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) |
|
|
|
|
|
freqs = torch.einsum("i,j->ij", t, self.inv_freq) |
|
|
|
|
|
emb = torch.cat((freqs, freqs), dim=-1) |
|
|
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) |
|
|
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) |
|
|
|
|
|
|
|
|
|
|
|
class LlamaAttention(nn.Module): |
|
|
"""Multi-headed attention from 'Attention Is All You Need' paper""" |
|
|
|
|
|
def __init__(self, config): |
|
|
super().__init__() |
|
|
self.config = config |
|
|
self.hidden_size = config.hidden_size |
|
|
self.num_heads = config.num_attention_heads |
|
|
self.head_dim = self.hidden_size // self.num_heads |
|
|
self.num_key_value_heads = config.num_key_value_heads |
|
|
self.num_key_value_groups = self.num_heads // self.num_key_value_heads |
|
|
self.max_position_embeddings = config.max_position_embeddings |
|
|
|
|
|
if (self.head_dim * self.num_heads) != self.hidden_size: |
|
|
raise ValueError( |
|
|
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" |
|
|
f" and `num_heads`: {self.num_heads})." |
|
|
) |
|
|
self.q_proj = nn.Linear(self.hidden_size * 2, self.num_heads * self.head_dim, bias=False) |
|
|
self.k_proj = nn.Linear(self.hidden_size * 2, self.num_key_value_heads * self.head_dim, bias=False) |
|
|
self.v_proj = nn.Linear(self.hidden_size * 2, self.num_key_value_heads * self.head_dim, bias=False) |
|
|
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) |
|
|
self._init_rope() |
|
|
|
|
|
def _init_rope(self): |
|
|
if self.config.rope_scaling is None: |
|
|
self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) |
|
|
else: |
|
|
scaling_type = self.config.rope_scaling["type"] |
|
|
scaling_factor = self.config.rope_scaling["factor"] |
|
|
if scaling_type == "linear": |
|
|
self.rotary_emb = LlamaLinearScalingRotaryEmbedding( |
|
|
self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor |
|
|
) |
|
|
elif scaling_type == "dynamic": |
|
|
self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( |
|
|
self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor |
|
|
) |
|
|
else: |
|
|
raise ValueError(f"Unknown RoPE scaling type {scaling_type}") |
|
|
|
|
|
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): |
|
|
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
hidden_states: torch.Tensor, |
|
|
cache_hidden: Optional[List[torch.Tensor]] = None, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
position_ids: Optional[torch.LongTensor] = None, |
|
|
past_key_value: Optional[Tuple[torch.Tensor]] = None, |
|
|
output_attentions: bool = False, |
|
|
use_cache: bool = False, |
|
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
|
|
bsz, q_len, _ = hidden_states.size() |
|
|
|
|
|
query_states = self.q_proj(hidden_states) |
|
|
key_states = self.k_proj(hidden_states) |
|
|
value_states = self.v_proj(hidden_states) |
|
|
|
|
|
        lck = len(cache_hidden[0]) if cache_hidden is not None else 0  # number of previously cached draft steps
|
|
|
|
|
|
|
|
|
|
|
|
|
|
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) |
|
|
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
|
|
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
|
|
|
|
|
|
|
|
cos, sin = self.rotary_emb(query_states, seq_len=q_len + lck) |
|
|
cos, sin = cos.to(query_states.device), sin.to(query_states.device) |
|
|
|
|
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids + lck) |
|
|
|
|
|
|
|
|
key_states = repeat_kv(key_states, self.num_key_value_groups) |
|
|
value_states = repeat_kv(value_states, self.num_key_value_groups) |
|
|
|
|
|
|
|
|
|
|
|
if cache_hidden is None: |
|
|
local_cache_k = [] |
|
|
local_cache_v = [] |
|
|
else: |
|
|
local_cache_k = list(cache_hidden[0]) |
|
|
local_cache_v = list(cache_hidden[1]) |
|
|
|
|
|
local_cache_k.append(key_states) |
|
|
local_cache_v.append(value_states) |
|
|
|
|
|
cache_k = local_cache_k |
|
|
cache_v = local_cache_v |
|
|
|
|
|
k0 = cache_k[0] |
|
|
v0 = cache_v[0] |
|
|
lck = len(cache_k) |
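
        # cache_k[0] holds the keys of the full input sequence, so the query
        # block attends to it with an ordinary causal (q_len x q_len) score
        # matrix. Each later cache entry holds keys produced at one earlier
        # draft step for the same positions, so it contributes exactly one
        # extra attention column per position (appended in the loop below).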
|
|
|
|
|
attn_weights = torch.matmul(query_states, k0.transpose(2, 3)) / math.sqrt(self.head_dim) |
|
|
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask
|
|
|
|
|
for i in range(1, lck): |
|
|
ki = cache_k[i] |
|
|
|
|
|
            # Element-wise dot product: each query attends to the single cached
            # key at its own position from draft step i.
            attn_weightsi = (query_states * ki).sum(-1) / math.sqrt(self.head_dim)
|
|
attn_weights = torch.cat((attn_weights, attn_weightsi[..., None]), dim=-1) |
|
|
|
|
|
|
|
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) |
|
|
attn_weights0 = attn_weights[..., :q_len] |
|
|
|
|
|
attn_output = torch.matmul(attn_weights0, v0) |
|
|
|
|
|
for i in range(1, lck): |
|
|
vi = cache_v[i] |
|
|
attn_weightsi = attn_weights[..., q_len + i - 1] |
|
|
attn_outputi = attn_weightsi[..., None] * vi |
|
|
attn_output = attn_output + attn_outputi |
|
|
|
|
|
attn_output = attn_output.transpose(1, 2).contiguous() |
|
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) |
|
|
|
|
|
attn_output = self.o_proj(attn_output) |
|
|
|
|
|
|
|
|
        new_past_key_value = [local_cache_k, local_cache_v]
|
|
return attn_output, new_past_key_value |
|
|
|
|
|
|
|
|
class LlamaMLP(nn.Module): |
|
|
def __init__(self, config, last=True): |
|
|
super().__init__() |
|
|
self.last = last |
|
|
self.config = config |
|
|
self.hidden_size = config.hidden_size |
|
|
self.intermediate_size = config.intermediate_size |
|
|
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) |
|
|
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) |
|
|
|
|
|
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) |
|
|
|
|
|
|
|
|
self.act_fn = ACT2FN[config.hidden_act] |
|
|
|
|
|
def forward(self, x): |
|
|
if self.config.pretraining_tp > 1: |
|
|
            slice_size = self.intermediate_size // self.config.pretraining_tp
|
|
            gate_proj_slices = self.gate_proj.weight.split(slice_size, dim=0)
|
|
            up_proj_slices = self.up_proj.weight.split(slice_size, dim=0)
|
|
            down_proj_slices = self.down_proj.weight.split(slice_size, dim=1)
|
|
|
|
|
gate_proj = torch.cat( |
|
|
[F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 |
|
|
) |
|
|
up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) |
|
|
|
|
|
            intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice_size, dim=2)
|
|
down_proj = [ |
|
|
F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) |
|
|
] |
|
|
down_proj = sum(down_proj) |
|
|
else: |
|
|
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) |
|
|
|
|
|
return down_proj |
|
|
|
|
|
|
|
|
class LlamaRMSNorm(nn.Module): |
|
|
def __init__(self, hidden_size, eps=1e-6): |
|
|
""" |
|
|
LlamaRMSNorm is equivalent to T5LayerNorm |
|
|
""" |
|
|
super().__init__() |
|
|
self.weight = nn.Parameter(torch.ones(hidden_size)) |
|
|
self.variance_epsilon = eps |
|
|
|
|
|
def forward(self, hidden_states): |
|
|
input_dtype = hidden_states.dtype |
|
|
hidden_states = hidden_states.to(torch.float32) |
|
|
variance = hidden_states.pow(2).mean(-1, keepdim=True) |
|
|
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) |
|
|
return self.weight * hidden_states.to(input_dtype) |
|
|
|
|
|
|
|
|
class LlamaDecoderLayeremb(nn.Module): |
|
|
def __init__(self, config, last=True): |
|
|
super().__init__() |
|
|
self.hidden_size = config.hidden_size |
|
|
self.self_attn = LlamaAttention(config=config) |
|
|
self.mlp = LlamaMLP(config, last=last) |
|
|
self.last = last |
|
|
|
|
|
self.hidden_norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
|
|
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
|
|
|
|
|
|
|
|
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
input_emb: torch.Tensor, |
|
|
hidden_states: torch.Tensor, |
|
|
        cache_hidden: Optional[List[List[torch.Tensor]]] = None,
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
position_ids: Optional[torch.LongTensor] = None, |
|
|
past_key_value: Optional[Tuple[torch.Tensor]] = None, |
|
|
output_attentions: Optional[bool] = False, |
|
|
use_cache: Optional[bool] = False, |
|
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: |
|
|
""" |
|
|
Args: |
|
|
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` |
|
|
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size |
|
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. |
|
|
output_attentions (`bool`, *optional*): |
|
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under |
|
|
returned tensors for more detail. |
|
|
use_cache (`bool`, *optional*): |
|
|
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding |
|
|
(see `past_key_values`). |
|
|
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states |
|
|
""" |
|
|
|
|
|
residual = hidden_states |
|
|
|
|
|
hidden_states = self.hidden_norm(hidden_states) |
|
|
input_emb = self.input_layernorm(input_emb) |
|
|
|
|
|
hidden_states = torch.cat((input_emb, hidden_states), dim=-1) |
|
|
|
|
|
        # Keep the concatenated (embedding, hidden) input around; it is
        # returned alongside the layer output for downstream use.
        return_hidden = hidden_states
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hidden_states, latest_hidden_cache = self.self_attn( |
|
|
cache_hidden=cache_hidden, |
|
|
hidden_states=hidden_states, |
|
|
attention_mask=attention_mask, |
|
|
position_ids=position_ids, |
|
|
past_key_value=past_key_value, |
|
|
output_attentions=output_attentions, |
|
|
use_cache=use_cache, |
|
|
) |
|
|
hidden_states = residual + hidden_states |
|
|
|
|
|
|
|
|
residual = hidden_states |
|
|
|
|
|
hidden_states = self.post_attention_layernorm(hidden_states) |
|
|
hidden_states = self.mlp(hidden_states) |
|
|
hidden_states = residual + hidden_states |
|
|
|
|
|
outputs = (hidden_states, return_hidden) |
|
|
|
|
|
|
|
|
return outputs, latest_hidden_cache |
|
|
|
|
|
|
|
|
@torch.no_grad() |
|
|
def padding(tensor, left=True):
    """Shift `tensor` by one step along the sequence dim, zero-filling the gap."""
    zeropadding = torch.zeros_like(tensor[:, -1:])
|
|
if left: |
|
|
tensor = torch.cat((zeropadding, tensor[:, :-1]), dim=1) |
|
|
else: |
|
|
tensor = torch.cat((tensor[:, 1:], zeropadding), dim=1) |
|
|
return tensor |
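
# Illustrative example: with left=False, padding shifts everything one step to
# the left and zero-fills the final position, so row t of the result lines up
# with token t+1 of the input. `dataprepare` uses this to shift both the
# target logits and the input ids one position ahead of the hidden states.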
|
|
|
|
|
|
|
|
def process_data(data_chunk): |
|
|
|
|
|
token_dict = Counter() |
|
|
input_ids = data_chunk["input_ids"] |
|
|
loss_mask = data_chunk["loss_mask"] |
|
|
for i in range(len(input_ids)): |
|
|
        ids = input_ids[i][0]
|
|
mask = loss_mask[i][0] |
|
|
for j in range(len(ids)): |
|
|
if mask[j] == 1: |
|
|
token_dict[ids[j]] += 1 |
|
|
|
|
|
return token_dict |
|
|
|
|
|
|
|
|
def merge_dicts(dicts): |
|
|
"""合并多个 Counter 字典""" |
|
|
result = Counter() |
|
|
for d in dicts: |
|
|
result.update(d) |
|
|
return result |
|
|
|
|
|
|
|
|
class Model(nn.Module): |
|
|
def __init__(self, config, ds_config, training_config, load_head=False, load_emb=True, path=None, model_type='llama'): |
|
|
super().__init__() |
|
|
self.model_type = model_type |
|
|
|
|
|
|
|
|
self.train_config = training_config |
|
|
|
|
|
if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3: |
|
|
            # Keep a reference alive so the HF/DeepSpeed integration detects
            # ZeRO-3 when the target model is loaded with from_pretrained.
            dschf = HfDeepSpeedConfig(ds_config)
|
|
else: |
|
|
dschf = None |
|
|
self.midlayer = LlamaDecoderLayeremb(config) |
|
|
self.gradient_checkpointing = self.train_config["gradient_checkpointing"] |
|
|
self.padding_idx = config.pad_token_id |
|
|
self.vocab_size = config.vocab_size |
|
|
self.hidden_size = config.hidden_size |
|
|
self.draft_vocab_size = config.draft_vocab_size |
|
|
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
|
|
self.length = 6 |
|
|
|
|
|
|
|
|
if self.model_type == 'qwen3': |
|
|
self.target_model = Qwen3ForCausalLM.from_pretrained(path, torch_dtype=torch.float16) |
|
|
else: |
|
|
self.target_model = LlamaForCausalLM.from_pretrained(path, torch_dtype=torch.float16) |
|
|
|
|
|
self.target_model.eval() |
|
|
        # Fuse the three target-model hidden states (concatenated along the
        # feature dim in `dataprepare`) back down to one hidden size.
        self.fc = nn.Linear(self.hidden_size * 3, self.hidden_size, bias=False)
|
|
for param in self.target_model.parameters(): |
|
|
param.requires_grad = False |
|
|
|
|
|
if not load_emb: |
|
|
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) |
|
|
|
|
|
else: |
|
|
|
|
|
from safetensors import safe_open |
|
|
import json |
|
|
import os |
|
|
try: |
|
|
with open(os.path.join(path, "model.safetensors.index.json"), "r") as f: |
|
|
index_json = json.loads(f.read()) |
|
|
emb_path = index_json["weight_map"]["model.embed_tokens.weight"] |
|
|
with safe_open(os.path.join(path, emb_path), |
|
|
framework="pt", |
|
|
device="cpu") as f: |
|
|
tensor_slice = f.get_slice("model.embed_tokens.weight") |
|
|
vocab_size, hidden_dim = tensor_slice.get_shape() |
|
|
tensor = tensor_slice[:, :hidden_dim].float() |
|
|
            except (FileNotFoundError, KeyError):
                # Fall back to a pytorch_model.bin checkpoint layout.
|
|
with open(os.path.join(path, "pytorch_model.bin.index.json"), "r") as f: |
|
|
index_json = json.loads(f.read()) |
|
|
emb_path = index_json["weight_map"]["model.embed_tokens.weight"] |
|
|
weights = torch.load(os.path.join(path, emb_path)) |
|
|
tensor = weights["model.embed_tokens.weight"].float() |
|
|
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, _weight=tensor) |
|
|
|
|
|
self.lm_head = nn.Linear(config.hidden_size, config.draft_vocab_size, bias=False) |
|
|
|
|
|
for param in self.embed_tokens.parameters(): |
|
|
param.requires_grad = False |
|
|
|
|
|
def scandata(self, datapath, tokenizerpath): |
|
|
N = self.draft_vocab_size |
|
|
|
|
|
|
|
|
cache_file = f"cache_{self.model_type}.pt" if self.model_type != 'llama' else "cache.pt" |
|
|
|
|
|
if not os.path.exists(cache_file): |
|
|
tokenizer = AutoTokenizer.from_pretrained(tokenizerpath) |
|
|
dataset = load_dataset('json', data_files=datapath) |
|
|
dataset = dataset['train'] |
|
|
|
|
|
original_columns1 = dataset.column_names |
|
|
num_proc = 1 |
|
|
|
|
|
|
|
|
if self.model_type == 'qwen3': |
|
|
sep = "<|im_end|>\n<|im_start|>assistant\n" |
|
|
sep2 = "<|im_end|>\n<|im_start|>user\n" |
|
|
else: |
|
|
sep = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" |
|
|
sep2 = "<|eot_id|><|start_header_id|>user<|end_header_id|>" |
|
|
|
|
|
def preprocess_function(examples): |
|
|
new_examples = { |
|
|
|
|
|
"input_ids": [], |
|
|
"loss_mask": [] |
|
|
} |
|
|
for i in range(len(examples['id'])): |
|
|
messages = [ |
|
|
{"role": "system", |
|
|
"content": "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."}, |
|
|
] |
|
|
convroles = ["user", "assistant"] |
|
|
roles = {"human": "user", "gpt": "assistant"} |
|
|
source = examples['conversations'][i] |
|
|
if not source: |
|
|
continue |
|
|
if roles[source[0]["from"]] != "user": |
|
|
|
|
|
source = source[1:] |
|
|
for j, sentence in enumerate(source): |
|
|
role = roles[sentence["from"]] |
|
|
assert role == convroles[j % 2], f"{i}" |
|
|
|
|
|
|
|
|
messages.append( |
|
|
{"role": role, "content": sentence["value"]} |
|
|
) |
|
|
conversation = tokenizer.apply_chat_template( |
|
|
messages, |
|
|
tokenize=False, |
|
|
add_generation_prompt=False, |
|
|
) |
|
|
|
|
|
if not tokenizer.pad_token_id: |
|
|
tokenizer.pad_token_id = tokenizer.unk_token_id |
|
|
|
|
|
input_ids = tokenizer( |
|
|
conversation, |
|
|
return_tensors="pt", |
|
|
add_special_tokens=False, |
|
|
).input_ids[0] |
|
|
|
|
|
|
|
|
|
|
|
if len(input_ids) > self.train_config["max_len"]: |
|
|
continue |
|
|
loss_mask = torch.ones_like(input_ids) |
|
|
|
|
|
|
|
|
total_len = len(input_ids) |
|
|
|
|
|
turns = conversation.split(sep2) |
|
|
|
|
|
|
|
|
if len(turns) < 2: |
|
|
continue |
|
|
|
|
|
turns[1] = turns[0] + sep2 + turns[1] |
|
|
turns = turns[1:] |
|
|
|
|
|
cur_len = 1 |
|
|
loss_mask[:cur_len] = 0 |
|
|
                    for turn_idx, turn in enumerate(turns):
|
|
if turn == "": |
|
|
break |
|
|
turn_len = len(tokenizer(turn).input_ids) |
|
|
|
|
|
parts = turn.split(sep) |
|
|
if len(parts) != 2: |
|
|
break |
|
|
parts[0] += sep |
|
|
|
|
|
instruction_len = len(tokenizer(parts[0]).input_ids) - 1 |
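
                        # The small +-2/+-3 offsets below account for the
                        # special tokens the chat template inserts around role
                        # headers; the goal is to keep loss only on assistant
                        # tokens.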
|
|
|
|
|
|
|
|
                        if turn_idx == 0:
|
|
loss_mask[cur_len: cur_len + instruction_len - 2] = 0 |
|
|
else: |
|
|
loss_mask[cur_len - 3: cur_len + instruction_len + 1] = 0 |
|
|
cur_len += turn_len |
|
|
                        if turn_idx != 0:
|
|
cur_len += 3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
loss_mask[cur_len:] = 0 |
|
|
|
|
|
|
|
|
new_examples["input_ids"].append(input_ids[None, :]) |
|
|
new_examples["loss_mask"].append(loss_mask[None, :]) |
|
|
|
|
|
return new_examples |
|
|
|
|
|
dataset = dataset.map( |
|
|
preprocess_function, |
|
|
batched=True, |
|
|
num_proc=num_proc, |
|
|
remove_columns=original_columns1, |
|
|
load_from_cache_file=False |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
            # The whole dataset is processed as a single chunk; the chunked
            # form is kept so this could be parallelised (e.g. with
            # multiprocessing) later.
            chunks = [dataset[i:i + len(dataset)] for i in range(0, len(dataset), len(dataset))]
            results = [process_data(chunk) for chunk in chunks]
|
|
|
|
|
|
|
|
token_dict = merge_dicts(results) |
|
|
|
|
|
|
|
|
total_frequency = sum(token_dict.values()) |
|
|
top_N = token_dict.most_common(N) |
|
|
top_N_frequency_sum = sum(freq for key, freq in top_N) |
|
|
top_N_ratio = top_N_frequency_sum / total_frequency |
|
|
print(f"top {N} token frequency ratio: {top_N_ratio:.2%}") |
|
|
used_tokens = [key for key, freq in top_N] |
|
|
used_tokens.sort() |
|
|
            # d2t stores, for each draft-vocab index, the offset back to the
            # target-vocab token id; t2d flags which target tokens are kept in
            # the draft vocabulary.
            d2t = [used_tokens[i] - i for i in range(len(used_tokens))]
            t2d = [i in used_tokens for i in range(self.vocab_size)]
|
|
d2t = torch.tensor(d2t) |
|
|
t2d = torch.tensor(t2d) |
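
            # Worked example (illustrative): if used_tokens == [5, 9, 42], then
            # d2t == [5, 8, 40], so draft index 1 maps back to target id
            # 1 + d2t[1] == 9, and t2d is True only at positions 5, 9 and 42.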
|
|
cache = { |
|
|
"d2t": d2t, |
|
|
"t2d": t2d |
|
|
} |
|
|
torch.save(cache, cache_file) |
|
|
else: |
|
|
cache = torch.load(cache_file) |
|
|
d2t = cache["d2t"] |
|
|
t2d = cache["t2d"] |
|
|
self.register_buffer("d2t", d2t) |
|
|
self.register_buffer("t2d", t2d) |
|
|
self.l1smooth = nn.SmoothL1Loss(reduction="none") |
|
|
|
|
|
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): |
|
|
|
|
|
|
|
|
combined_attention_mask = None |
|
|
if input_shape[-1] > 1: |
|
|
combined_attention_mask = _make_causal_mask( |
|
|
input_shape, |
|
|
inputs_embeds.dtype, |
|
|
device=inputs_embeds.device, |
|
|
past_key_values_length=past_key_values_length, |
|
|
) |
|
|
|
|
|
if attention_mask is not None: |
|
|
|
|
|
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( |
|
|
inputs_embeds.device |
|
|
) |
|
|
combined_attention_mask = ( |
|
|
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask |
|
|
) |
|
|
|
|
|
return combined_attention_mask |
|
|
|
|
|
@torch.no_grad() |
|
|
def dataprepare(self, input_ids, attention_mask, loss_mask): |
|
|
device = input_ids.device |
|
|
outs = self.target_model(input_ids=input_ids, attention_mask=attention_mask) |
|
|
hidden_states0 = outs.hidden_states[0] |
|
|
hidden_states1 = outs.hidden_states[1] |
|
|
hidden_states2 = outs.hidden_states[2] |
|
|
        # Concatenate the three hidden states exposed by the (modified) target
        # model; `self.fc` later projects this 3*hidden_size input back down.
        hidden_states = torch.cat((hidden_states0, hidden_states1, hidden_states2), dim=-1)
|
|
|
|
|
target = outs.logits |
|
|
target = padding(target, left=False) |
|
|
input_ids = padding(input_ids, left=False) |
|
|
|
|
|
if target is not None: |
|
|
target = target.to(device) |
|
|
loss_mask = loss_mask[..., None] |
|
|
loss_mask = loss_mask.to(device) |
|
|
|
|
|
return hidden_states, target, loss_mask, input_ids |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
|
|
|
input_ids, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
position_ids: Optional[torch.LongTensor] = None, |
|
|
past_key_values: Optional[List[torch.FloatTensor]] = None, |
|
|
use_cache: Optional[bool] = None, |
|
|
output_attentions: Optional[bool] = None, |
|
|
output_hidden_states: Optional[bool] = None, |
|
|
loss_mask: Optional[torch.Tensor] = None, |
|
|
|
|
|
): |
|
|
hidden_states, target, loss_mask, input_ids = self.dataprepare(input_ids, attention_mask, loss_mask) |
|
|
|
|
|
batch_size, seq_length, _ = hidden_states.shape |
|
|
seq_length_with_past = seq_length |
|
|
past_key_values_length = 0 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.training and self.gradient_checkpointing and not hidden_states.requires_grad: |
|
|
hidden_states.requires_grad = True |
|
|
|
|
|
        hidden_states = self.fc(hidden_states)
|
|
|
|
|
if past_key_values is not None: |
|
|
past_key_values_length = past_key_values[0][0].shape[2] |
|
|
seq_length_with_past = seq_length_with_past + past_key_values_length |
|
|
if position_ids is None: |
|
|
device = hidden_states.device |
|
|
position_ids = torch.arange( |
|
|
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device |
|
|
) |
|
|
position_ids = position_ids.unsqueeze(0).view(-1, seq_length) |
|
|
else: |
|
|
position_ids = position_ids.view(-1, seq_length).long() |
|
|
|
|
|
if attention_mask is None: |
|
|
attention_mask = torch.ones( |
|
|
(batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device |
|
|
) |
|
|
attention_mask = self._prepare_decoder_attention_mask( |
|
|
attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length |
|
|
) |
|
|
|
|
|
if self.gradient_checkpointing and self.training: |
|
|
if use_cache: |
|
|
use_cache = False |
|
|
|
|
|
plosses = [] |
|
|
        vlosses = []  # kept for interface compatibility; no value loss is computed below
|
|
acces = [] |
|
|
cache_hidden = [[], []] |
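
        # Unroll `self.length` draft steps: at every step the draft layer sees
        # the input embeddings shifted one more position ahead together with
        # its own previous output, and the attention caches from earlier steps
        # are reused through `cache_hidden`.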
|
|
|
|
|
for idx in range(self.length): |
|
|
last = idx == self.length - 1 |
|
|
inputs_embeds = self.embed_tokens(input_ids) |
|
|
if self.training and self.gradient_checkpointing and not inputs_embeds.requires_grad: |
|
|
inputs_embeds.requires_grad = True |
|
|
inputs_embeds = inputs_embeds.to(hidden_states.dtype) |
|
|
|
|
|
if self.gradient_checkpointing and self.training: |
|
|
|
|
|
def create_custom_forward(module): |
|
|
def custom_forward(*inputs): |
|
|
|
|
|
return module(*inputs, None, output_attentions) |
|
|
|
|
|
return custom_forward |
|
|
|
|
|
layer_outputs, cache_hidden = torch.utils.checkpoint.checkpoint( |
|
|
create_custom_forward(self.midlayer), |
|
|
inputs_embeds, |
|
|
hidden_states, |
|
|
cache_hidden, |
|
|
attention_mask, |
|
|
position_ids, |
|
|
) |
|
|
else: |
|
|
|
|
|
layer_outputs, cache_hidden = self.midlayer( |
|
|
input_emb=inputs_embeds, |
|
|
hidden_states=hidden_states, |
|
|
cache_hidden=cache_hidden, |
|
|
attention_mask=attention_mask, |
|
|
position_ids=position_ids, |
|
|
past_key_value=None, |
|
|
output_attentions=output_attentions, |
|
|
use_cache=True, |
|
|
) |
|
|
|
|
|
hidden_states_out = layer_outputs[0] |
|
|
|
|
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
|
|
|
target_head = target |
|
|
target_max_token = target_head.argmax(-1) |
|
|
|
|
|
self.t2d = self.t2d.to(target_max_token.device) |
|
|
target_mask = self.t2d[target_max_token] |
|
|
target_mask = target_mask[..., None].int() |
|
|
position_mask = target_mask * loss_mask |
|
|
target_head = target_head[..., self.t2d] |
|
|
target_head = target_head.float() |
|
|
target_p = nn.Softmax(dim=2)(target_head) |
|
|
target_p = target_p.detach() |
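
                # `target_p` is the target model's next-token distribution
                # restricted to the draft vocabulary (columns selected by t2d);
                # positions whose argmax falls outside the draft vocab are
                # zeroed out of the loss via `position_mask`.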
|
|
|
|
|
|
|
|
|
|
|
hidden_states = hidden_states_out |
|
|
|
|
|
hidden_states_out = self.norm(hidden_states_out) |
|
|
|
|
|
logits = self.lm_head(hidden_states_out) |
|
|
logits = logits.float() |
|
|
out_logp = nn.LogSoftmax(dim=2)(logits) |
|
|
plogp = target_p * out_logp |
|
|
loss = -torch.sum(position_mask * plogp, 2).mean() |
|
|
plosses.append(loss) |
|
|
with torch.no_grad(): |
|
|
|
|
|
acces.append(((logits.argmax(-1) == target_p.argmax(-1)) * position_mask.squeeze(-1)).sum().item() / ( |
|
|
position_mask.sum().item() + 1e-6)) |
|
|
|
|
|
if not last: |
|
|
input_ids = padding(input_ids, left=False) |
|
|
target = padding(target, left=False) |
|
|
loss_mask = padding(loss_mask, left=False) |
|
|
|
|
|
|
|
|
|
|
|
return plosses, vlosses, acces |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
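# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; paths and config values are hypothetical
# and depend on the surrounding training pipeline):
#
#   config = EConfig.from_pretrained("train_config.json")
#   training_config = {"gradient_checkpointing": False, "max_len": 2048}
#   model = Model(config, ds_config=None, training_config=training_config,
#                 path="/path/to/target-model", model_type="llama")
#   model.scandata("sharegpt.jsonl", "/path/to/target-model")
#   plosses, vlosses, acces = model(input_ids, attention_mask=attention_mask,
#                                   loss_mask=loss_mask)
# ---------------------------------------------------------------------------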