# ymodel1_1.py (YModel 1.1)
import math
import torch
import torch.nn as nn
from typing import Optional, Tuple, Union, List
from transformers import PreTrainedModel, GenerationMixin
from transformers.activations import ACT2FN
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.configuration_utils import PretrainedConfig
class YConfig1_1(PretrainedConfig):
model_type = "ynet"
def __init__(
self,
dropout: float = 0.1,
bos_token_id: int = 1,
eos_token_id: int = 2,
hidden_act: str = 'gelu_pytorch_tanh',
exp: float = 3.0,
ffn_shared: int = 3,
hidden_size: int = 512,
        intermediate_size: Optional[int] = None,
max_position_embeddings: int = 8192,
num_heads: int = 8,
num_layers: int = 9,
pe_dim: int = 64,
head_dim: int = 64,
groups: int = 4,
vocab_size: int = 6400,
rms_norm_eps: float = 1e-7,
        rope_theta: float = 5e4,
flash_attn: bool = True,
self_distill: bool = True,
**kwargs
):
super().__init__(**kwargs)
self.dropout = dropout
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.hidden_act = hidden_act
        self.exp = exp  # FFN expansion ratio
        self.ffn_shared = ffn_shared  # number of layers sharing the FFN up & down weights
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.max_position_embeddings = max_position_embeddings
        self.num_heads = num_heads  # number of query heads
        self.num_layers = num_layers  # number of layers
        self.pe_dim = pe_dim  # positional embedding dimension
        self.head_dim = head_dim  # per-head dimension
        self.groups = groups  # GQA: query heads per KV group
self.vocab_size = vocab_size
self.rms_norm_eps = rms_norm_eps
self.rope_theta = rope_theta
self.flash_attn = flash_attn
self.self_distill = self_distill
def scale_lvl(self, lvl:int=0):
if lvl == 0:
# normal settings [80.27m]
self.exp = 3.0
self.ffn_shared = 3
self.hidden_size = 512
self.num_heads = 12
self.num_layers = 27
self.pe_dim = 96
self.head_dim = 64
self.groups = 6
elif lvl == -1:
# small -1 [24m]
self.exp = 3.0
self.ffn_shared = 3
self.hidden_size = 512
self.num_heads = 8
self.num_layers = 12
self.pe_dim = 64
self.head_dim = 64
self.groups = 8
elif lvl == -2:
# small -2 [12m]
self.exp = 2.0
self.ffn_shared = 4
self.hidden_size = 512
self.num_heads = 7
self.num_layers = 8
self.pe_dim = 48
self.head_dim = 48
self.groups = 6
elif lvl == -3:
# small -3 [6m]
self.exp = 2.0
self.ffn_shared = 3
self.hidden_size = 384
self.num_heads = 7
self.num_layers = 6
self.pe_dim = 48
self.head_dim = 32
self.groups = 6
######## large #######
elif lvl == 1:
# large +1 [0.2b]
self.exp = 2.0
self.ffn_shared = 3
self.hidden_size = 768
self.num_heads = 12
self.num_layers = 24
self.pe_dim = 96
self.head_dim = 64
self.groups = 6
elif lvl == 2:
# large +2 [0.6b]
self.exp = 3.0
self.ffn_shared = 3
self.hidden_size = 1344
self.num_heads = 25
self.num_layers = 24
self.pe_dim = 192
self.head_dim = 96
self.groups = 7
else:
raise ValueError(f"Invalid level: {lvl}")
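# Illustrative usage sketch (not part of the original file): build a config and apply a
# size preset; scale_lvl(-1) selects the roughly 24M-parameter configuration.
# cfg = YConfig1_1()
# cfg.scale_lvl(-1)
# assert cfg.num_layers % cfg.ffn_shared == 0  # required by YModel below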
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float())
output = output * self.weight.float()
return output.type_as(x)
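# RMSNorm rescales by the root-mean-square of the last dimension:
#   y = x / sqrt(mean(x ** 2, dim=-1) + eps) * weight
# Minimal sketch (hypothetical shapes, not from the original file):
# norm = RMSNorm(dim=512)
# y = norm(torch.randn(2, 16, 512))  # same shape and dtype as the input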
def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), theta: float = 5e4):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device)
freqs = torch.outer(t, freqs).float()
freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1)
freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1)
return freqs_cos, freqs_sin
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=0):
def rotate_half(x):
return torch.cat((-x[..., x.shape[-1] // 2:], x[..., : x.shape[-1] // 2]), dim=-1)
q_embed = (q * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(q) * sin.unsqueeze(unsqueeze_dim))
k_embed = (k * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(k) * sin.unsqueeze(unsqueeze_dim))
return q_embed, k_embed
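# Minimal sketch of pairing the two RoPE helpers above (hypothetical shapes): the tables
# are precomputed once for pe_dim and sliced to the sequence length before being applied.
# cos, sin = precompute_freqs_cis(dim=64, end=8192, theta=5e4)  # each [8192, 64]
# q = torch.randn(2, 128, 64)
# k = torch.randn(2, 128, 64)
# q_rot, k_rot = apply_rotary_pos_emb(q, k, cos[:128], sin[:128])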
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
b, h, l, ch = x.shape
if n_rep == 1:
return x
return (
x[:, :, None, :, :]
.expand(b, h, n_rep, l, ch)
.reshape(b, h * n_rep, l, ch)
)
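# Minimal GQA sketch (hypothetical shapes): one KV head is expanded so that every query
# head in its group attends to the same keys/values.
# kv_heads = torch.randn(2, 2, 128, 64)    # [b, kv_heads, l, head_dim]
# expanded = repeat_kv(kv_heads, n_rep=4)  # [2, 8, 128, 64]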
class PEGA(nn.Module):
"""
位置编码门控注意力
"""
def __init__(self, config: YConfig1_1):
super().__init__()
self.dropout = config.dropout # dropout rate
        self.hidden_size = config.hidden_size # input (hidden) channel size
        self.num_heads = config.num_heads # total number of attention heads
        self.pe_dim = config.pe_dim # positional embedding dimension
        self.head_dim = config.head_dim # per-head dimension
        self.groups = config.groups # GQA: query heads per KV group
self.hidden_kv_dim = int(self.head_dim * self.num_heads // self.groups)
self.gate_act = ACT2FN[config.hidden_act]
self.delta_kv_only = False
assert self.num_heads % self.groups == 0, "num_heads must be divisible by groups"
# self.qpe = nn.Linear(self.hidden_size, self.pe_dim, bias=False)
# self.kpe = nn.Linear(self.hidden_size, self.pe_dim, bias=False)
# self.q = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
# self.kv = nn.Linear(self.hidden_size, self.hidden_kv_dim, bias=False)
# equals to above
self.qkv_list = [self.pe_dim, self.pe_dim, self.num_heads * self.head_dim, self.hidden_kv_dim]
self.qkv = nn.Linear(self.hidden_size, sum(self.qkv_list), bias=False)
self.o = nn.Linear(self.num_heads * self.hidden_kv_dim, self.hidden_size, bias=False)
self.gate = nn.Linear(self.hidden_kv_dim, self.num_heads * self.hidden_kv_dim, bias=False)
self.rsqrt_dim = 1.0 / math.sqrt(self.head_dim)
def forward(
self,
x: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
b, l, _ = x.shape
        cos, sin = position_embeddings # each of shape [seq_len, pe_dim]
# qpe = self.qpe(x) # [b, l, pe]
# kpe = self.kpe(x) # [b, l, pe]
# q = self.q(x) # [b, l, nope * hc]
# kv = self.kv(x) # [b, l, ckv]
qkv = self.qkv(x)
qpe, kpe, q, kv = torch.split(qkv, self.qkv_list, dim=-1)
        # apply RoPE to the positional query/key slices
qpe, kpe = apply_rotary_pos_emb(
qpe,
kpe,
cos[:l],
sin[:l],
)
deltakv = None
if self.delta_kv_only:
            # return only the newly computed (delta) key/value states
deltakv = (kpe, kv)
        # KV cache: concatenate past key positions and values before the new ones
if past_key_value is not None:
kpe = torch.cat([past_key_value[0], kpe], dim=1)
kv = torch.cat([past_key_value[1], kv], dim=1)
past_kv = (kpe, kv) if use_cache else None
_, l_all, _ = kv.shape
dropout_p = self.dropout if self.training else 0.0
        attn_mask = None
        if attention_mask is not None:
            # broadcast the padding mask to [b, 1, l, l_all] and make it boolean
            attn_mask = attention_mask.view(b, 1, 1, -1).expand(b, 1, l, -1).bool()
        qpe = qpe.reshape(b, l, 1, self.pe_dim).permute(0, 2, 1, 3) # [b, 1, l, pe_dim]
        kpe = kpe.reshape(b, l_all, 1, self.pe_dim).permute(0, 2, 1, 3) # [b, 1, l_all, pe_dim]
        q = q.reshape(b, l, self.num_heads, self.head_dim).permute(0, 2, 1, 3) # [b, num_heads, l, head_dim]
        nopek = kv.reshape(b, l_all, self.num_heads // self.groups, self.head_dim).permute(0, 2, 1, 3) # [b, num_heads // groups, l_all, head_dim]
        kv = kv.reshape(b, l_all, 1, self.hidden_kv_dim).permute(0, 2, 1, 3) # [b, 1, l_all, hidden_kv_dim]
        if self.training:
            # F.scaled_dot_product_attention does not accept an explicit attn_mask together
            # with is_causal=True, so padded batches fall back to the math implementation.
            if attn_mask is not None:
                peo = self.sdpa_math(qpe, kpe, kv, attn_mask, dropout_p)
                nopeo = self.sdpa_math(q, repeat_kv(nopek, self.groups), repeat_kv(kv, self.num_heads), attn_mask, dropout_p)
            else:
                peo = nn.functional.scaled_dot_product_attention(
                    qpe, kpe, kv,
                    dropout_p=dropout_p, is_causal=True
                )
                nopeo = nn.functional.scaled_dot_product_attention(
                    q, repeat_kv(nopek, self.groups), repeat_kv(kv, self.num_heads),
                    dropout_p=dropout_p, is_causal=True
                )
else:
# peo = nn.functional.scaled_dot_product_attention(
# qpe, kpe, kv,
# attn_mask=attn_mask, dropout_p=dropout_p if self.training else 0.0, is_causal=l != 1
# )
# nopeo = nn.functional.scaled_dot_product_attention(
# q, repeat_kv(nopek, self.groups), repeat_kv(kv, self.num_heads),
# attn_mask=attn_mask, dropout_p=dropout_p if self.training else 0.0, is_causal=l != 1
# )
peo = self.sdpa_math(qpe, kpe, kv, attn_mask, 0.0)
nopeo = self.sdpa_math(q, repeat_kv(nopek, self.groups), repeat_kv(kv, self.num_heads), attn_mask, 0.0)
peo = peo.permute(0, 2, 1, 3).reshape(b, l, -1)
nopeo = nopeo.permute(0, 2, 1, 3).reshape(b, l, -1)
gate = self.gate_act(self.gate(peo))
out = nopeo * gate
out = self.o(out)
out = nn.functional.dropout(out, p=self.dropout, training=self.training)
return out, (deltakv if self.delta_kv_only else past_kv)
    def sdpa_math(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: Optional[torch.Tensor] = None,
                  dropout_p: float = 0.0) -> torch.Tensor:
b, h, l, c = q.shape
scores = (q @ k.transpose(-2, -1)) * self.rsqrt_dim
        causal_mask = torch.triu(
            torch.full((l, l), float("-inf"), device=scores.device),
            diagonal=1
        ).unsqueeze(0).unsqueeze(0)# [1, 1, l, l]
        # left-pad with zeros so the mask matches scores' shape [1, 1, l, l_all]
        causal_mask = nn.functional.pad(causal_mask, (scores.shape[-1] - l, 0), "constant", 0.0)# [1, 1, l, l_all]
        scores += causal_mask
if attn_mask is not None:
attn_mask = (1.0 - attn_mask.type_as(scores)) * -1e9
scores = scores + attn_mask
scores = nn.functional.softmax(scores.float(), dim=-1).type_as(q)
        scores = nn.functional.dropout(scores, p=dropout_p, training=self.training)# [b, h, l, l_all]
output = scores @ v
return output
def use_delta_kv_only(self, enable:bool=True):
        # return only the delta KV to reduce memory overhead
self.delta_kv_only = enable
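# Minimal PEGA sketch (hypothetical, not part of the original file): a single forward pass
# with caching enabled; position_embeddings are the RoPE tables built for pe_dim.
# cfg = YConfig1_1()
# pega = PEGA(cfg)
# cos, sin = precompute_freqs_cis(dim=cfg.pe_dim, end=cfg.max_position_embeddings, theta=cfg.rope_theta)
# x = torch.randn(2, 16, cfg.hidden_size)
# out, kv_cache = pega(x, position_embeddings=(cos, sin), use_cache=True)
# out.shape          # [2, 16, cfg.hidden_size]
# kv_cache[0].shape  # [2, 16, cfg.pe_dim]          rotated key positions
# kv_cache[1].shape  # [2, 16, pega.hidden_kv_dim]  compressed key/value states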
class YFFN(nn.Module):
"""
shared up & down GeGLU, LoE (Lack of Expert) arc
"""
def __init__(self, config: YConfig1_1):
super().__init__()
self.act = ACT2FN[config.hidden_act]
self.channels = config.hidden_size
self.exp = config.exp
self.c_up = int(self.channels * self.exp)
self.ffn_shared = config.ffn_shared
self.up = nn.Linear(self.channels, self.c_up, bias=False)
self.down = nn.Linear(self.c_up, self.channels, bias=False)
self.gates = nn.ModuleList([
nn.Linear(self.channels, self.c_up, bias=False) for _ in range(self.ffn_shared)
])
    def forward(self, x: torch.Tensor, index: int, up_res: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
up = self.up(x)
if up_res is not None:
up += up_res
gate = self.gates[index](x)
gate = self.act(gate)
up *= gate
x = self.down(up)
return x, up
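# Minimal sketch of the sharing scheme (hypothetical, not part of the original file): one
# up/down pair serves ffn_shared consecutive layers, each layer owning only its gate, and
# the gated `up` activation is passed to the next layer in the group as `up_res`.
# ffn = YFFN(YConfig1_1())
# h = torch.randn(2, 16, 512)
# y0, up0 = ffn(h, index=0)              # first layer of the group, no carried activation
# y1, up1 = ffn(h, index=1, up_res=up0)  # next layer reuses up0 as its residual input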
class YBlock(nn.Module):
"""
Groups of Transformer layers with shared FFN
num layers is ffn_shared
"""
def __init__(self, config: YConfig1_1):
super().__init__()
self.attentions = nn.ModuleList([PEGA(config) for _ in range(config.ffn_shared)])
self.ffn = YFFN(config)
self.attn_norms = nn.ModuleList([
RMSNorm(config.hidden_size, eps=config.rms_norm_eps) for _ in range(config.ffn_shared)
])
self.ffn_norms = nn.ModuleList([
RMSNorm(config.hidden_size, eps=config.rms_norm_eps) for _ in range(config.ffn_shared)
])
self.use_self_distill = config.self_distill
def forward(self,
x: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,# ffn_shared KV caches, one per attention layer
use_cache: bool = False,
attention_mask: Optional[torch.Tensor] = None
):
b, l, _ = x.shape
kv_outs = []
ups = None
cos_loss = None
for i, (layer, kv_cache) in enumerate(zip(self.attentions, past_key_values)):
x0 = x
res = x
x = self.attn_norms[i](x)
x, kv_out = layer(
x = x,
position_embeddings=position_embeddings,
past_key_value=kv_cache,
attention_mask=attention_mask,
use_cache=use_cache
)
x += res
res = x
x = self.ffn_norms[i](x)
x, ups = self.ffn(x, i, ups)
x += res
kv_outs.append(kv_out)
if self.training and self.use_self_distill:
xd = x.detach()
# cosine loss
c_loss = 1.0 - nn.functional.cosine_similarity(x0, xd, dim=-1).mean()
cos_loss = c_loss + cos_loss if cos_loss is not None else c_loss
return x, kv_outs, cos_loss
def delta_kv_only(self, delta_kv:bool=True):
for i in range(len(self.attentions)):
self.attentions[i].use_delta_kv_only(delta_kv)
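# Minimal YBlock sketch (hypothetical, not part of the original file): a block stacks
# ffn_shared attention layers around one shared YFFN and expects one KV-cache slot per layer.
# cfg = YConfig1_1()
# blk = YBlock(cfg)
# cos, sin = precompute_freqs_cis(dim=cfg.pe_dim, end=cfg.max_position_embeddings, theta=cfg.rope_theta)
# x = torch.randn(2, 16, cfg.hidden_size)
# x, kv_outs, cos_loss = blk(x, position_embeddings=(cos, sin),
#                            past_key_values=[None] * cfg.ffn_shared, use_cache=True)
# len(kv_outs)  # == cfg.ffn_shared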
class YModel(nn.Module):
def __init__(self, config: YConfig1_1):
super().__init__()
self.vocab_size = config.vocab_size
self.num_layers = config.num_layers
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.dropout = config.dropout
self.ffn_shared = config.ffn_shared
assert self.num_layers % self.ffn_shared == 0, "num_layers must be divisible by ffn_shared"
self.blks = nn.ModuleList([
YBlock(config) for _ in range(self.num_layers // self.ffn_shared)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
freqs_cos, freqs_sin = precompute_freqs_cis(dim=config.pe_dim,
end=config.max_position_embeddings, theta=config.rope_theta)
self.register_buffer("freqs_cos", freqs_cos, persistent=False)
self.register_buffer("freqs_sin", freqs_sin, persistent=False)
def forward(self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: bool = False,
**kwargs
):
batch_size, seq_length = input_ids.shape
past_key_values = past_key_values or [None] * self.num_layers
start_pos = past_key_values[0][0].shape[1] if past_key_values[0] is not None else 0
x = self.embed_tokens(input_ids)
x = nn.functional.dropout(x, p=self.dropout, training=self.training)
position_embeddings = (
self.freqs_cos[start_pos:start_pos + seq_length],
self.freqs_sin[start_pos:start_pos + seq_length]
)
presents = []
cos_loss = None
for layer_idx, block in enumerate(self.blks):
past_key_value = past_key_values[self.ffn_shared * layer_idx: self.ffn_shared * (layer_idx + 1)]
x, present, c_loss = block(
x = x,
position_embeddings = position_embeddings,
past_key_values=past_key_value,
use_cache=use_cache,
attention_mask=attention_mask
)
presents.extend(present)
cos_loss = c_loss + cos_loss if cos_loss is not None else c_loss
x = self.norm(x)
return x, presents, (cos_loss / self.num_layers if cos_loss is not None else None)
def delta_kv_only(self, delta_kv:bool=True):
for i in range(len(self.blks)):
self.blks[i].delta_kv_only(delta_kv)
class YForCausalLM1_1(PreTrainedModel, GenerationMixin):
config_class = YConfig1_1
def __init__(self, config: YConfig1_1 = None):
self.config = config or YConfig1_1()
super().__init__(self.config)
self.model = YModel(self.config)
self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
self.model.embed_tokens.weight = self.lm_head.weight
self.OUT = CausalLMOutputWithPast()
def forward(self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: bool = False,
logits_to_keep: Union[int, torch.Tensor] = 0,
**args):
h, past_kvs, cos_loss = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
**args
)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(h[:, slice_indices, :])
self.OUT.__setitem__('last_hidden_state', h)
self.OUT.__setitem__('logits', logits)
self.OUT.__setitem__('aux_loss', 0.0)
self.OUT.__setitem__('past_key_values', past_kvs)
self.OUT.__setitem__('dist_loss', cos_loss)
return self.OUT
def delta_kv_only(self, delta_kv:bool=True):
self.model.delta_kv_only(delta_kv)
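# Minimal end-to-end sketch (hypothetical, not part of the original file): build the default
# config, run one forward pass, and read logits and per-layer KV caches from the output.
if __name__ == "__main__":
    cfg = YConfig1_1()
    model = YForCausalLM1_1(cfg).eval()
    ids = torch.randint(0, cfg.vocab_size, (1, 16))
    with torch.no_grad():
        out = model(input_ids=ids, use_cache=True)
    print(out.logits.shape)          # [1, 16, cfg.vocab_size]
    print(len(out.past_key_values))  # == cfg.num_layers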