# maotao / model/model_ribo.py
# Uploaded by julse — commit 4707555 ("upload AA2CDS", verified)
import math
import struct
import inspect
import time
from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List
import numpy as np
import torch
from fairseq import utils
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
class RMSNorm(torch.nn.Module):
    """Root-mean-square layer normalization with a learned per-channel gain.

    Computes ``weight * x / sqrt(mean(x ** 2, dim=-1) + eps)``. The mean square
    is computed in float32 so that half-precision inputs cannot overflow when
    squared; the normalized value is cast back to the input dtype before the
    gain is applied.

    Args:
        dim: Size of the last (normalized) dimension.
        eps: Constant added to the mean square for numerical stability.
    """

    def __init__(self, dim: int, eps: float):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Square in float32: in fp16, x.pow(2) overflows to inf once |x| > ~256,
        # which would make rsqrt(...) == 0 and wipe out the whole row.
        x_fp32 = x.float()
        normed = x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
        return self.weight * normed.type_as(x)
def precompute_pos_cis(dim: int, end: int, theta: float = 1e4):
    """Precompute rotary-embedding rotation factors as unit complex numbers.

    Returns a complex64 tensor of shape ``(end, dim // 2)`` whose entry
    ``[t, i]`` is ``exp(1j * t * theta ** (-2 * i / dim))``.
    """
    half = dim // 2
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2)[:half].float() / dim))
    positions = torch.arange(end, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq).float()
    return torch.polar(torch.ones_like(angles), angles)
def apply_rotary_emb(xq, xk, pos_cis):
    """Rotate query and key vectors by the rotary position factors ``pos_cis``.

    Consecutive pairs along the last axis of ``xq``/``xk`` are treated as
    complex numbers and multiplied by ``pos_cis``. ``pos_cis`` is viewed so that
    its first axis lines up with axis 1 of the inputs and its second axis with
    the last (complex) axis; all other axes broadcast. The final ``flatten(3)``
    only restores the input shape for 4-D inputs, so inputs are assumed to be
    ``(batch, seq, heads, head_dim)``.

    Returns:
        The rotated ``(xq, xk)``, cast back to their original dtypes.
    """

    def _broadcastable(freqs, ref):
        rank = ref.ndim
        # Same condition as the original `assert 0 <= 1 < ndim`: rank must exceed 1.
        assert rank > 1
        target = [size if axis in (1, rank - 1) else 1 for axis, size in enumerate(ref.shape)]
        return freqs.view(*target)

    q_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    k_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    rot = _broadcastable(pos_cis, q_complex)
    q_rotated = torch.view_as_real(q_complex * rot).flatten(3)
    k_rotated = torch.view_as_real(k_complex * rot).flatten(3)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each KV head ``n_rep`` times along axis 2.

    Equivalent to ``torch.repeat_interleave(x, dim=2, repeats=n_rep)`` for a
    4-D ``(batch, seq, n_kv_heads, head_dim)`` tensor. Returns ``x`` itself
    (no copy) when ``n_rep == 1``.
    """
    batch, length, kv_heads, width = x.shape
    if n_rep != 1:
        expanded = x.unsqueeze(3).expand(batch, length, kv_heads, n_rep, width)
        return expanded.reshape(batch, length, kv_heads * n_rep, width)
    return x
class Attention(nn.Module):
    """Multi-head self-attention with RoPE, grouped KV heads, an optional KV
    cache, and an optional additive bias derived from a 2-D pairwise map.

    Args:
        args: Model config. Reads ``dim``, ``n_heads``, ``n_kv_heads``,
            ``dropout``, ``flash_attn`` and ``max_seq_len``.
    """

    def __init__(self, args: LMConfig):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert args.n_heads % self.n_kv_heads == 0
        self.n_local_heads = args.n_heads
        self.n_local_kv_heads = self.n_kv_heads
        # Number of query heads that share each KV head.
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        # Maps each single-channel cell of the 2-D map to one scalar attention
        # bias, shared by all heads.
        self.twod_proj = nn.Linear(1, 1)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
        # Warn whenever the manual path will be used: either the config disabled
        # flash attention or this PyTorch build has no SDPA.
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self,
                x: torch.Tensor,
                pos_cis: torch.Tensor,
                twod_tokens: Optional[torch.Tensor] = None,
                past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
                use_cache=False, is_causal=False):
        """Attend over ``x``, unpacked as ``(batch, seq_len, dim)``.

        Args:
            x: Input hidden states.
            pos_cis: Rotary factors for the ``seq_len`` positions of ``x``.
            twod_tokens: Optional pairwise map turned into an additive attention
                bias by ``twod_proj``. Its channel axis (dim 1) must have size 1,
                because ``twod_proj`` is ``Linear(1, 1)``.
            past_key_value: Cached ``(keys, values)`` prepended on the sequence axis.
            use_cache: If true, return the concatenated ``(keys, values)``.
            is_causal: Mask attention to future positions. Honoured by both the
                SDPA path and the manual path.

        Returns:
            ``(output, past_kv)``. ``past_kv`` is None unless ``use_cache`` is set.
        """
        bsz, seq_len, _ = x.shape
        xq = self.wq(x).view(bsz, seq_len, self.n_local_heads, self.head_dim)
        xk = self.wk(x).view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xv = self.wv(x).view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xq, xk = apply_rotary_emb(xq, xk, pos_cis)

        # KV cache: prepend keys/values from earlier steps along the sequence axis.
        if past_key_value is not None:
            xk = torch.cat([past_key_value[0], xk], dim=1)
            xv = torch.cat([past_key_value[1], xv], dim=1)
        past_kv = (xk, xv) if use_cache else None

        xq = xq.transpose(1, 2)
        xk = repeat_kv(xk, self.n_rep).transpose(1, 2)
        xv = repeat_kv(xv, self.n_rep).transpose(1, 2)

        twod_bias = None
        if twod_tokens is not None:
            # (B, 1, S, S) -> (B, S, S, 1) -> Linear(1, 1) -> (B, S, S, 1) -> (B, 1, S, S);
            # the singleton head axis broadcasts over every head.
            twod_bias = self.twod_proj(twod_tokens.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        if self.flash and seq_len != 1:
            # Single-token steps (seq_len == 1, e.g. cached decoding) use the manual path.
            dropout_p = self.dropout if self.training else 0.0
            output = F.scaled_dot_product_attention(
                xq, xk, xv,
                attn_mask=twod_bias,
                dropout_p=dropout_p,
                is_causal=is_causal,
            )
        else:
            scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
            if twod_bias is not None:
                # Broadcasting over the head axis equals repeating the bias per head.
                scores = scores + twod_bias
            if is_causal:
                # Match the SDPA path, which only masks future positions when asked to.
                scores = scores + self.mask[:, :, :seq_len, :seq_len]
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = scores @ xv

        output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
        output = self.resid_dropout(self.wo(output))
        return output, past_kv
class FeedForward(nn.Module):
    """SwiGLU feed-forward layer: ``dropout(w2(silu(w1(x)) * w3(x)))``.

    If ``config.hidden_dim`` is None, it is written back *onto the config
    object* as two thirds of ``4 * config.dim``, rounded up to a multiple of
    ``config.multiple_of``. Modules built later from the same config then
    reuse that value.
    """

    def __init__(self, config: LMConfig):
        super().__init__()
        if config.hidden_dim is None:
            base = int(2 * (4 * config.dim) / 3)
            step = config.multiple_of
            config.hidden_dim = step * ((base + step - 1) // step)
        # Creation order (w1, w2, w3) is kept so random initialisation is unchanged.
        self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
        self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        gated = F.silu(self.w1(x))
        projected = self.w3(x)
        return self.dropout(self.w2(gated * projected))
class MoEGate(nn.Module):
    """Top-k softmax router for the mixture-of-experts layer.

    ``forward`` returns ``(topk_idx, topk_weight, aux_loss)`` for the flattened
    tokens. ``topk_idx`` holds the chosen expert indices and ``topk_weight``
    their gate weights, each of shape ``(bsz * seq_len, num_experts_per_tok)``.
    ``aux_loss`` is the load-balancing auxiliary loss, or the int ``0`` outside
    training or when ``aux_loss_alpha <= 0``.
    """

    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts
        self.scoring_func = config.scoring_func
        self.alpha = config.aux_loss_alpha
        self.seq_aux = config.seq_aux
        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.dim
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Same Kaiming-uniform scheme nn.Linear uses for its weight.
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        bsz, seq_len, h = hidden_states.shape
        tokens = hidden_states.view(-1, h)
        logits = F.linear(tokens, self.weight, None)
        if self.scoring_func != 'softmax':
            raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')
        scores = logits.softmax(dim=-1)

        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
        if self.top_k > 1 and self.norm_topk_prob:
            # Renormalise the selected gate weights to sum to 1 per token.
            topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20)

        if not (self.training and self.alpha > 0.0):
            return topk_idx, topk_weight, 0

        per_seq_idx = topk_idx.view(bsz, -1)
        if self.seq_aux:
            # Sequence-level balance: per-sequence expert load times mean routing prob.
            per_seq_scores = scores.view(bsz, seq_len, -1)
            load = torch.zeros(bsz, self.n_routed_experts, device=tokens.device)
            load.scatter_add_(1, per_seq_idx,
                              torch.ones(bsz, seq_len * self.top_k, device=tokens.device))
            load.div_(seq_len * self.top_k / self.n_routed_experts)
            aux_loss = (load * per_seq_scores.mean(dim=1)).sum(dim=1).mean() * self.alpha
        else:
            # Batch-level balance: fraction of assignments per expert times mean prob.
            one_hot = F.one_hot(per_seq_idx.view(-1), num_classes=self.n_routed_experts)
            frac = one_hot.float().mean(0)
            mean_prob = scores.mean(0)
            scaled_frac = frac * self.n_routed_experts
            aux_loss = (mean_prob * scaled_frac).sum() * self.alpha
        return topk_idx, topk_weight, aux_loss
class NonLinearHead(nn.Module):
    """Two-layer MLP head: ``linear2(activation(linear1(x)))``.

    Args:
        input_dim: Size of the input features.
        out_dim: Size of the output features.
        activation_fn: Activation name, resolved with fairseq's
            ``utils.get_activation_fn``.
        hidden: Hidden width. Any falsy value (None, 0) means ``input_dim``.
    """

    def __init__(self, input_dim, out_dim, activation_fn, hidden=None):
        super().__init__()
        width = hidden or input_dim
        self.linear1 = nn.Linear(input_dim, width)
        self.linear2 = nn.Linear(width, out_dim)
        self.activation_fn = utils.get_activation_fn(activation_fn)

    def forward(self, x):
        return self.linear2(self.activation_fn(self.linear1(x)))
class MOEFeedForward(nn.Module):
    """Mixture-of-experts SwiGLU feed-forward layer.

    ``MoEGate`` routes each token to ``num_experts_per_tok`` of the
    ``n_routed_experts`` experts, and the expert outputs are combined with the
    gate weights. When ``config.n_shared_experts`` is not None, one extra shared
    ``FeedForward`` is applied to every token and added. The gate's auxiliary
    loss from the most recent call is stored in ``self.aux_loss``.
    """

    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config
        self.experts = nn.ModuleList([
            FeedForward(config)
            for _ in range(config.n_routed_experts)
        ])
        self.gate = MoEGate(config)
        if config.n_shared_experts is not None:
            self.shared_experts = FeedForward(config)

    def forward(self, x):
        identity = x
        orig_shape = x.shape
        # Pick experts for every token.
        topk_idx, topk_weight, aux_loss = self.gate(x)
        x = x.view(-1, x.shape[-1])
        flat_topk_idx = topk_idx.view(-1)
        if self.training:
            # Duplicate each token once per selected expert, so that row r of `x`
            # is processed by expert flat_topk_idx[r].
            x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
            # Use the activations' own dtype. The previous hard-coded float16 buffer
            # overflowed for bf16 activations and lost precision for fp32 ones.
            y = torch.empty_like(x)
            for i, expert in enumerate(self.experts):
                routed = flat_topk_idx == i
                y[routed] = expert(x[routed]).to(y.dtype)
            # (N * k, dim) -> (N, k, dim), weighted sum over the k selected experts.
            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
            y = y.view(*orig_shape)
        else:
            # Inference: dispatch grouped by expert, without gradients.
            y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(identity)
        self.aux_loss = aux_loss
        return y

    @torch.no_grad()
    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
        """Run each expert once on all tokens routed to it and sum the results.

        Args:
            x: Flattened tokens, ``(num_tokens, dim)``.
            flat_expert_indices: Chosen expert for each (token, slot) pair,
                flattened token-major (entry ``token * k + slot``).
            flat_expert_weights: Matching gate weights, shape ``(N, 1)``.

        Returns:
            A ``(num_tokens, dim)`` tensor holding, for each token, the
            gate-weighted sum of its selected experts' outputs.
        """
        expert_cache = torch.zeros_like(x)
        # Sort (token, slot) entries by expert so each expert's entries are contiguous.
        idxs = flat_expert_indices.argsort()
        # Cumulative counts give the end offset of each expert's run within `idxs`.
        tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
        # Map every sorted (token, slot) entry back to its token row.
        token_idxs = idxs // self.config.num_experts_per_tok
        # Example: tokens_per_expert=[6, 15, 20, ...] and
        # token_idxs=[3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12, ...] means
        # token_idxs[:6] -> [3, 7, 19, 21, 24, 25] go to expert 0,
        # token_idxs[6:15] go to expert 1, and so on.
        for i, end_idx in enumerate(tokens_per_expert):
            start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
            if start_idx == end_idx:
                continue
            expert = self.experts[i]
            exp_token_idx = token_idxs[start_idx:end_idx]
            expert_out = expert(x[exp_token_idx]).to(expert_cache.dtype)
            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
            # Accumulate into each token's row; a token hit by several experts gets the sum.
            expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)
        return expert_cache
class MiniMindBlock(nn.Module):
    """Pre-norm transformer layer: ``h = x + attn(norm(x))``, ``out = h + ffn(norm(h))``.

    The feed-forward sublayer is an ``MOEFeedForward`` when ``config.use_moe``
    is true, and a dense ``FeedForward`` otherwise.
    """

    def __init__(self, layer_id: int, config: LMConfig):
        super().__init__()
        self.n_heads = config.n_heads
        self.dim = config.dim
        self.head_dim = config.dim // config.n_heads
        # Submodule creation order is kept so random initialisation is unchanged.
        self.attention = Attention(config)
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.feed_forward = MOEFeedForward(config) if config.use_moe else FeedForward(config)

    def forward(self, x, pos_cis, twod_tokens=None, past_key_value=None, use_cache=False):
        attn_out, past_kv = self.attention(
            self.attention_norm(x),
            pos_cis,
            twod_tokens=twod_tokens,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )
        residual = x + attn_out
        return residual + self.feed_forward(self.ffn_norm(residual)), past_kv
class MiniMindLM(PreTrainedModel):  # student
    """Transformer over token ids with an optional 2-D pairwise attention bias.

    ``forward`` returns a ``CausalLMOutputWithPast`` that carries ``logits``,
    ``aux_loss`` (summed MoE load-balancing loss, 0 without MoE),
    ``past_key_values``, ``embeddings`` (the final normalized hidden states)
    and ``zero_shot`` (a pooled per-sequence score, computed only when the
    hidden states do not require grad).
    """
    config_class = LMConfig

    def __init__(self, params: LMConfig = None):
        self.params = params or LMConfig()
        super().__init__(self.params)
        # Read everything from self.params so the default path (params=None) works.
        params = self.params
        self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim, padding_idx=params.padding_idx)
        self.dropout = nn.Dropout(params.dropout)
        self.layers = nn.ModuleList([MiniMindBlock(l, params) for l in range(self.n_layers)])
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.logit_dim, bias=False)
        # Tie input embeddings to the output projection.
        # NOTE(review): this assumes params.logit_dim == params.vocab_size — confirm.
        self.tok_embeddings.weight = self.output.weight
        self.register_buffer("pos_cis", precompute_pos_cis(params.dim // params.n_heads, params.max_seq_len,
                                                           theta=params.rope_theta), persistent=False)
        # Most recent forward() output, kept for callers that read `model.OUT`.
        self.OUT = CausalLMOutputWithPast()

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                twod_tokens: Optional[torch.Tensor] = None,
                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                use_cache: bool = False,
                **args):
        """Run the model on ``input_ids``, shaped ``(batch, seq_len)``.

        Args:
            input_ids: Token ids. Positions equal to 1 are treated as padding:
                their hidden states are zeroed after the embedding and after
                every layer.
            twod_tokens: Optional pairwise map passed to every attention layer
                as an additive bias, after conversion to float32. Attention
                requires its channel axis (dim 1) to have size 1.
            past_key_values: Per-layer KV caches from a previous call.
            use_cache: If true, return updated per-layer KV caches.
            **args: ``start_pos`` (int, default 0) offsets the rotary positions
                when decoding with a cache.

        Returns:
            A new ``CausalLMOutputWithPast`` for this call, also stored as
            ``self.OUT``.
        """
        past_key_values = past_key_values or [None] * len(self.layers)
        start_pos = args.get('start_pos', 0)
        if twod_tokens is not None:
            twod_tokens = twod_tokens.to(torch.float32)
        h = self.dropout(self.tok_embeddings(input_ids))
        # NOTE(review): padding id is hard-coded as 1 here, independently of
        # params.padding_idx — confirm the two agree.
        seq_mask = (input_ids == 1).unsqueeze(-1)  # (batch, seq_len, 1)
        h = h.masked_fill_(seq_mask, 0)
        pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
        past_kvs = []
        for l, layer in enumerate(self.layers):
            h, past_kv = layer(
                h, pos_cis,
                twod_tokens=twod_tokens,
                past_key_value=past_key_values[l],
                use_cache=use_cache,
            )
            h = h.masked_fill_(seq_mask, 0)
            past_kvs.append(past_kv)
        h = self.norm(h)
        logits = self.output(h)
        aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))

        if not h.requires_grad:
            keep = ~seq_mask
            # Sum of every hidden value at non-pad positions, divided by the number
            # of non-pad positions (not by positions * dim).
            sum_h = torch.sum(h * keep, dim=(1, 2))
            count_h = torch.sum(keep, dim=(1, 2))
            mean_h = sum_h / count_h
            # All-padding rows give 0/0 = NaN; report 0 for them instead.
            mean_h[count_h == 0] = 0
            zero_shot = mean_h.reshape(-1, 1)
        else:
            zero_shot = None

        # Build a fresh output per call so results of earlier calls are not overwritten.
        out = CausalLMOutputWithPast()
        out['logits'] = logits
        out['aux_loss'] = aux_loss
        out['past_key_values'] = past_kvs
        out['embeddings'] = h
        out['zero_shot'] = zero_shot  # pooled score used for zero-shot evaluation
        self.OUT = out
        return out

    @torch.inference_mode()
    def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
                 stream=False, rp=1., use_cache=True, pad_token_id=0, **args):
        """Sample continuations of ``input_ids``.

        Despite its name, ``max_new_tokens`` bounds the *total* sequence length:
        sampling stops once prompt + generated tokens reach ``max_new_tokens - 1``,
        or when ``eos_token_id`` is sampled.

        With ``stream=True``, returns the ``_generate_stream`` generator for the
        raw ``input_ids`` (batch size 1). Otherwise each row has its
        ``pad_token_id`` tokens removed and is generated independently. The
        resulting rows (prompt + continuation) are right-padded with
        ``pad_token_id`` and stacked into one ``(batch, max_len)`` tensor.
        """
        if stream:
            return self._generate_stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache)
        generated = []
        for row in input_ids:
            non_pad = row[row != pad_token_id].unsqueeze(0)
            out = self._generate_stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache)
            # Each yield is the whole continuation so far; its last column is the new token.
            tokens_list = [tokens[:, -1:] for tokens in out]
            # If nothing was generated, append an empty tensor (not the prompt again).
            gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad[:, :0]
            generated.append(torch.cat([non_pad, gen], dim=-1))
        max_length = max(seq.size(1) for seq in generated)
        generated = [
            torch.cat(
                [seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],
                dim=-1)
            for seq in generated
        ]
        return torch.cat(generated, dim=0)

    @torch.inference_mode()
    def _generate_stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args):
        """Yield the tokens generated so far after each sampling step.

        Assumes batch size 1 (it reads ``input_ids.tolist()[0]`` and calls
        ``.item()``). Decorated with ``inference_mode`` so that iterating the
        generator returned by ``generate(stream=True)`` builds no autograd graph.
        """
        start, first_seq, past_kvs = input_ids.shape[1], True, None
        while input_ids.shape[1] < max_new_tokens - 1:
            if first_seq or not use_cache:
                out, first_seq = self(input_ids, past_key_values=past_kvs, use_cache=use_cache), False
            else:
                # With a cache, feed only the newest token at its absolute position.
                out = self(input_ids[:, -1:], past_key_values=past_kvs, use_cache=use_cache,
                           start_pos=input_ids.shape[1] - 1)
            logits, past_kvs = out.logits[:, -1, :], out.past_key_values
            # Repetition penalty (CTRL-style): make already-seen tokens less likely
            # whatever the sign of their logit. Plain division would *raise*
            # negative logits when rp > 1.
            seen = list(set(input_ids.tolist()[0]))
            seen_logits = logits[:, seen]
            logits[:, seen] = torch.where(seen_logits < 0, seen_logits * rp, seen_logits / rp)
            logits /= (temperature + 1e-9)
            if top_p is not None and top_p < 1.0:
                # Nucleus sampling: keep the smallest prefix of tokens whose mass exceeds top_p.
                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
                sorted_probs = F.softmax(sorted_logits, dim=-1)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift right so the token that crosses the threshold is kept.
                sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                sorted_indices_to_remove[:, 0] = False
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                logits[indices_to_remove] = -float('Inf')
            input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            input_ids = torch.cat((input_ids, input_ids_next), dim=1)
            yield input_ids[:, start:]
            if input_ids_next.item() == eos_token_id:
                break