import inspect
import math
import struct
import time
from typing import Any, List, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from fairseq import utils
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

from .LMConfig import LMConfig


class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned per-channel scale."""

    def __init__(self, dim: int, eps: float):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # The normalised value is computed in float32 and cast back to x's dtype
        # before the learned scale is applied.
        return self.weight * (x.float() * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)).type_as(x)


def precompute_pos_cis(dim: int, end: int, theta: float = 1e4):
    """Precompute rotary-embedding (RoPE) rotation factors.

    Args:
        dim: per-head feature size; one frequency is produced per pair of features.
        end: number of positions to precompute.
        theta: RoPE base.

    Returns:
        complex64 tensor of shape ``(end, dim // 2)`` holding ``exp(i * t * freq_k)``.
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    pos_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return pos_cis


def apply_rotary_emb(xq, xk, pos_cis):
    """Rotate query and key features by the precomputed RoPE factors.

    ``xq`` / ``xk`` are 4-D ``(bsz, seq_len, n_heads, head_dim)`` tensors (the
    ``flatten(3)`` below relies on that rank); consecutive feature pairs are
    treated as complex numbers. ``pos_cis`` must be viewable as
    ``(seq_len, head_dim // 2)`` so it broadcasts over batch and heads.
    The outputs keep the dtypes of ``xq`` and ``xk``.
    """

    def unite_shape(pos_cis, x):
        # Reshape pos_cis to (1, seq_len, 1, head_dim // 2) for broadcasting against x.
        ndim = x.ndim
        assert 0 <= 1 < ndim
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return pos_cis.view(*shape)

    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    pos_cis = unite_shape(pos_cis, xq_)
    xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )


class Attention(nn.Module):
    """Multi-head (optionally grouped-query) self-attention with RoPE, a KV cache
    and an optional additive bias derived from a 2-D token map.
    """

    def __init__(self, args: LMConfig):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert args.n_heads % self.n_kv_heads == 0
        self.n_local_heads = args.n_heads
        self.n_local_kv_heads = self.n_kv_heads
        # How many query heads share each key/value head.
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        # Learned affine map (1 -> 1) applied to every entry of the 2-D token map
        # to turn it into an additive attention bias.
        self.twod_proj = nn.Linear(1, 1)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
        # Warn only when flash attention was requested but this PyTorch lacks SDPA.
        if args.flash_attn and not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self,
                x: torch.Tensor,
                pos_cis: torch.Tensor,
                twod_tokens: Optional[torch.Tensor] = None,
                past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
                use_cache=False,
                is_causal=False):
        """Attend over ``x`` of shape ``(bsz, seq_len, dim)``.

        Args:
            x: input hidden states.
            pos_cis: RoPE factors for the current positions (see apply_rotary_emb).
            twod_tokens: optional 4-D map whose dim 1 must have size 1 (it is
                permuted to the last axis and fed to ``Linear(1, 1)``); it becomes an
                additive bias on the attention scores. Presumably shaped
                ``(bsz, 1, seq_len, seq_len)`` — TODO confirm against callers.
            past_key_value: cached ``(keys, values)`` concatenated before the new ones.
            use_cache: when true, return the updated ``(keys, values)``.
            is_causal: apply a causal mask (both attention backends honour it).

        Returns:
            ``(output, past_kv)``: output of shape ``(bsz, seq_len, dim)`` and the
            updated cache (``None`` unless ``use_cache``).
        """
        bsz, seq_len, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, pos_cis)
        # KV cache: prepend previously computed keys/values along the sequence axis.
        if past_key_value is not None:
            xk = torch.cat([past_key_value[0], xk], dim=1)
            xv = torch.cat([past_key_value[1], xv], dim=1)
        past_kv = (xk, xv) if use_cache else None

        # -> (bsz, heads, seq, head_dim), expanding KV heads to match query heads.
        xq, xk, xv = (
            xq.transpose(1, 2),
            repeat_kv(xk, self.n_rep).transpose(1, 2),
            repeat_kv(xv, self.n_rep).transpose(1, 2),
        )

        if twod_tokens is not None:
            twod_tokens = twod_tokens.permute(0, 2, 3, 1)
            twod_tokens = self.twod_proj(twod_tokens)
            # Back to channel-first; the size-1 channel axis broadcasts across heads.
            twod_bias = twod_tokens.permute(0, 3, 1, 2)
        else:
            twod_bias = None

        # Single-token steps (incremental decoding) use the manual path; SDPA
        # offers little there.
        if self.flash and seq_len != 1:
            dropout_p = self.dropout if self.training else 0.0
            output = F.scaled_dot_product_attention(
                xq, xk, xv,
                attn_mask=twod_bias,
                dropout_p=dropout_p,
                is_causal=is_causal,
            )
        else:
            scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
            if twod_bias is not None:
                # Broadcasting the size-1 head axis is equivalent to repeating it per head.
                # NOTE(review): with a KV cache the bias must broadcast to (seq_len, kv_len) — confirm callers.
                scores = scores + twod_bias
            # Apply the causal mask only when requested, so this fallback computes the
            # same function as the SDPA branch above.
            if is_causal:
                scores = scores + self.mask[:, :, :seq_len, :seq_len]
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = scores @ xv

        output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
        output = self.resid_dropout(self.wo(output))
        return output, past_kv


class FeedForward(nn.Module):
    """Gated feed-forward block: ``dropout(w2(silu(w1(x)) * w3(x)))``."""

    def __init__(self, config: LMConfig):
        super().__init__()
        if config.hidden_dim is None:
            # Default hidden size: 2/3 of 4*dim, rounded up to a multiple of
            # config.multiple_of. It is written back into config, so later
            # FeedForward instances built from the same config reuse it.
            hidden_dim = 4 * config.dim
            hidden_dim = int(2 * hidden_dim / 3)
            config.hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of)
        self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
        self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))


class MoEGate(nn.Module):
    """Softmax router that picks ``num_experts_per_tok`` experts per token."""

    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts

        self.scoring_func = config.scoring_func
        self.alpha = config.aux_loss_alpha
        self.seq_aux = config.seq_aux

        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.dim
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        import torch.nn.init as init
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        """Route ``hidden_states`` of shape ``(bsz, seq_len, dim)``.

        Returns:
            ``(topk_idx, topk_weight, aux_loss)`` where the first two have shape
            ``(bsz * seq_len, top_k)`` and ``aux_loss`` is 0 unless training with
            ``aux_loss_alpha > 0``.

        Raises:
            NotImplementedError: if ``scoring_func`` is not ``'softmax'``.
        """
        bsz, seq_len, h = hidden_states.shape
        hidden_states = hidden_states.view(-1, h)
        logits = F.linear(hidden_states, self.weight, None)
        if self.scoring_func == 'softmax':
            scores = logits.softmax(dim=-1)
        else:
            raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')

        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        if self.top_k > 1 and self.norm_topk_prob:
            # Renormalise the selected weights so they sum to 1 per token.
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator

        if self.training and self.alpha > 0.0:
            scores_for_aux = scores
            aux_topk = self.top_k
            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
            if self.seq_aux:
                # Per-sequence variant: expert selection frequency per sample times
                # the mean router probability per sample.
                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
                ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
                ce.scatter_add_(1, topk_idx_for_aux_loss,
                                torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(
                    seq_len * aux_topk / self.n_routed_experts)
                aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
            else:
                # Batch-wide variant: fraction of selections per expert times mean probability.
                mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
                ce = mask_ce.float().mean(0)
                Pi = scores_for_aux.mean(0)
                fi = ce * self.n_routed_experts
                aux_loss = (Pi * fi).sum() * self.alpha
        else:
            aux_loss = 0
        return topk_idx, topk_weight, aux_loss


class NonLinearHead(nn.Module):
    """Head for simple classification tasks."""

    def __init__(
        self,
        input_dim,
        out_dim,
        activation_fn,
        hidden=None,
    ):
        super().__init__()
        hidden = input_dim if not hidden else hidden
        self.linear1 = nn.Linear(input_dim, hidden)
        self.linear2 = nn.Linear(hidden, out_dim)
        self.activation_fn = utils.get_activation_fn(activation_fn)

    def forward(self, x):
        x = self.linear1(x)
        x = self.activation_fn(x)
        x = self.linear2(x)
        return x


class MOEFeedForward(nn.Module):
    """Mixture-of-experts feed-forward layer.

    A MoEGate routes every token to ``num_experts_per_tok`` of the
    ``n_routed_experts`` FeedForward experts, and the expert outputs are combined
    with the gate weights. When ``config.n_shared_experts`` is set, one extra
    FeedForward applied to every token is added on top. The gate's auxiliary
    loss from the latest forward call is stored in ``self.aux_loss``.
    """

    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config
        self.experts = nn.ModuleList([
            FeedForward(config)
            for _ in range(config.n_routed_experts)
        ])
        self.gate = MoEGate(config)
        if config.n_shared_experts is not None:
            self.shared_experts = FeedForward(config)

    def forward(self, x):
        identity = x
        orig_shape = x.shape
        bsz, seq_len, _ = x.shape
        # Select experts with the gating network.
        topk_idx, topk_weight, aux_loss = self.gate(x)
        x = x.view(-1, x.shape[-1])
        flat_topk_idx = topk_idx.view(-1)
        if self.training:
            # Duplicate each token once per selected expert so row j of x lines up
            # with flat_topk_idx[j].
            x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
            # Allocate in the input dtype so expert outputs are not narrowed
            # (e.g. fp32 -> fp16) before being weighted and summed.
            y = torch.empty_like(x)
            for i, expert in enumerate(self.experts):
                y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype)
            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
            y = y.view(*orig_shape)
        else:
            # Inference: group tokens by expert and run each expert once on its batch.
            y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(identity)
        self.aux_loss = aux_loss
        return y

    @torch.no_grad()
    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
        """Inference-time expert dispatch.

        Args:
            x: flattened tokens, shape ``(n_tokens, dim)``.
            flat_expert_indices: expert id per (token, slot) pair in token-major
                order, length ``n_tokens * num_experts_per_tok``.
            flat_expert_weights: matching gate weights, shape ``(n_tokens * k, 1)``.

        Returns:
            Tensor like ``x`` holding, per token, the gate-weighted sum of its
            experts' outputs.
        """
        expert_cache = torch.zeros_like(x)
        idxs = flat_expert_indices.argsort()
        tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
        token_idxs = idxs // self.config.num_experts_per_tok
        # Example: with tokens_per_expert = [6, 15, 20, 26, 33, 38, 46, 52] and
        # token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12, ...],
        # tokens token_idxs[:6] = [3, 7, 19, 21, 24, 25] are handled by expert 0,
        # tokens token_idxs[6:15] by expert 1, and so on.
        for i, end_idx in enumerate(tokens_per_expert):
            start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
            if start_idx == end_idx:
                continue
            expert = self.experts[i]
            exp_token_idx = token_idxs[start_idx:end_idx]
            expert_tokens = x[exp_token_idx]
            expert_out = expert(expert_tokens).to(expert_cache.dtype)
            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
            # scatter_add_ sums contributions, so a token routed to several
            # experts accumulates all of their weighted outputs.
            expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)

        return expert_cache


class MiniMindBlock(nn.Module):
    """Pre-norm transformer block: ``h = x + attn(norm(x)); out = h + ffn(norm(h))``."""

    def __init__(self, layer_id: int, config: LMConfig):
        super().__init__()
        self.n_heads = config.n_heads
        self.dim = config.dim
        self.head_dim = config.dim // config.n_heads
        self.attention = Attention(config)

        self.layer_id = layer_id
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config)

    def forward(self, x, pos_cis, twod_tokens=None, past_key_value=None, use_cache=False):
        """Return ``(out, past_kv)``; ``twod_tokens`` is forwarded to the attention bias."""
        h_attn, past_kv = self.attention(
            self.attention_norm(x),
            pos_cis,
            twod_tokens=twod_tokens,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )
        h = x + h_attn
        out = h + self.feed_forward(self.ffn_norm(h))
        return out, past_kv


class MiniMindLM(PreTrainedModel):  # student
    """Transformer model over token ids with an optional 2-D attention bias.

    ``forward`` returns a reused ``CausalLMOutputWithPast`` carrying logits, the
    MoE auxiliary loss, the KV cache, final hidden states (``embeddings``) and,
    when gradients are not tracked, a per-sample pooled value (``zero_shot``).
    """
    config_class = LMConfig

    def __init__(self, params: LMConfig = None):
        self.params = params or LMConfig()
        super().__init__(self.params)
        # Use the resolved config from here on so params=None falls back to LMConfig() defaults.
        params = self.params
        self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim, padding_idx=params.padding_idx)
        self.dropout = nn.Dropout(params.dropout)
        self.layers = nn.ModuleList([MiniMindBlock(layer_id, params) for layer_id in range(self.n_layers)])
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.logit_dim, bias=False)
        # Tie input embedding and output projection weights.
        # NOTE(review): after tying, the embedding table has params.logit_dim rows;
        # presumably logit_dim == vocab_size — confirm.
        self.tok_embeddings.weight = self.output.weight
        self.register_buffer("pos_cis",
                             precompute_pos_cis(params.dim // params.n_heads, params.max_seq_len,
                                                theta=params.rope_theta),
                             persistent=False)
        # A single output object is reused (and overwritten) by every forward call.
        self.OUT = CausalLMOutputWithPast()

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                twod_tokens: Optional[torch.Tensor] = None,
                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                use_cache: bool = False,
                **args):
        """Run the model on ``input_ids`` of shape ``(bsz, seq_len)``.

        Args:
            input_ids: token ids; positions equal to 1 are treated as padding and
                their hidden states are zeroed after embedding and after every block.
            twod_tokens: optional map turned into an attention bias (cast to float32);
                may be omitted, e.g. during ``generate``.
            past_key_values: per-layer KV cache from a previous call.
            use_cache: return updated per-layer caches.
            **args: ``start_pos`` (default 0) offsets the RoPE positions.

        Returns:
            ``self.OUT`` with keys ``logits``, ``aux_loss``, ``past_key_values``,
            ``embeddings`` (final normalised hidden states) and ``zero_shot``
            (``(bsz, 1)`` tensor, or ``None`` when ``h`` requires grad).
        """
        past_key_values = past_key_values or [None] * len(self.layers)
        start_pos = args.get('start_pos', 0)
        # twod_tokens is optional; only cast when it was supplied.
        if twod_tokens is not None:
            twod_tokens = twod_tokens.to(torch.float32)
        h = self.dropout(self.tok_embeddings(input_ids))
        # NOTE(review): the padding id is hard-coded to 1 here, while generate()
        # defaults pad_token_id=0 and the embedding uses params.padding_idx — confirm.
        seq_mask = (input_ids == 1).unsqueeze(-1)
        h = h.masked_fill_(seq_mask, 0)
        pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
        past_kvs = []
        for layer_idx, layer in enumerate(self.layers):
            h, past_kv = layer(
                h, pos_cis,
                twod_tokens=twod_tokens,
                past_key_value=past_key_values[layer_idx],
                use_cache=use_cache,
            )
            # Keep padding positions at zero after every block.
            h = h.masked_fill_(seq_mask, 0)
            past_kvs.append(past_kv)
        h = self.norm(h)
        logits = self.output(h)
        aux_loss = sum(blk.feed_forward.aux_loss for blk in self.layers
                       if isinstance(blk.feed_forward, MOEFeedForward))

        if not h.requires_grad:
            not_pad = ~seq_mask
            # Sum of all hidden features over non-padding positions...
            sum_h = torch.sum(h * not_pad, dim=(1, 2))
            # ...divided by the number of non-padding positions.
            count_h = torch.sum(not_pad, dim=(1, 2))
            mean_h = sum_h / count_h
            # Samples consisting only of padding get 0 instead of NaN.
            mean_h[count_h == 0] = 0
            zero_shot = mean_h.reshape(-1, 1)
        else:
            zero_shot = None

        self.OUT['logits'] = logits
        self.OUT['aux_loss'] = aux_loss
        self.OUT['past_key_values'] = past_kvs
        self.OUT['embeddings'] = h
        self.OUT['zero_shot'] = zero_shot
        return self.OUT

    @torch.inference_mode()
    def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
                 stream=False, rp=1., use_cache=True, pad_token_id=0, **args):
        """Sample continuations of ``input_ids``.

        With ``stream=True`` the token generator from ``_generate_stream`` is
        returned directly. Otherwise each row has its ``pad_token_id`` tokens
        removed, is generated independently, and the results (prompt plus new
        tokens) are right-padded with ``pad_token_id`` and stacked.
        """
        # Streaming generation.
        if stream:
            return self._generate_stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache)

        # Non-streaming generation, one sample at a time.
        generated = []
        for i in range(input_ids.size(0)):
            non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
            out = self._generate_stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache)
            tokens_list = [tokens[:, -1:] for tokens in out]
            gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
            full_sequence = torch.cat([non_pad, gen], dim=-1)
            generated.append(full_sequence)
        max_length = max(seq.size(1) for seq in generated)
        generated = [
            torch.cat(
                [seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],
                dim=-1)
            for seq in generated
        ]
        return torch.cat(generated, dim=0)

    # The generator body runs lazily, after generate()'s own inference_mode context
    # has exited in stream mode, so it needs its own context.
    @torch.inference_mode()
    def _generate_stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args):
        """Yield the tokens generated so far (excluding the prompt) after every step.

        ``max_new_tokens`` bounds the TOTAL sequence length (prompt included): the
        loop runs while ``input_ids.shape[1] < max_new_tokens - 1``.
        NOTE(review): assumes batch size 1 (uses ``input_ids.tolist()[0]`` and
        ``.item()``).
        """
        start, first_seq, past_kvs = input_ids.shape[1], True, None
        while input_ids.shape[1] < max_new_tokens - 1:
            if first_seq or not use_cache:
                out, first_seq = self(input_ids, past_key_values=past_kvs, use_cache=use_cache), False
            else:
                # With a cache only the newest token is fed, offset to its RoPE position.
                out = self(input_ids[:, -1:], past_key_values=past_kvs, use_cache=use_cache,
                           start_pos=input_ids.shape[1] - 1)
            logits, past_kvs = out.logits[:, -1, :], out.past_key_values
            # Repetition penalty: divide logits of already-seen tokens by rp.
            # NOTE(review): dividing a negative logit by rp > 1 raises it — confirm intent.
            logits[:, list(set(input_ids.tolist()[0]))] /= rp
            logits /= (temperature + 1e-9)
            if top_p is not None and top_p < 1.0:
                # Nucleus sampling: drop tokens outside the smallest set whose
                # cumulative probability exceeds top_p (always keep the top token).
                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
                sorted_probs = F.softmax(sorted_logits, dim=-1)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                sorted_indices_to_remove[:, 0] = False
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                logits[indices_to_remove] = -float('Inf')
            input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            input_ids = torch.cat((input_ids, input_ids_next), dim=1)
            yield input_ids[:, start:]
            if input_ids_next.item() == eos_token_id:
                break