from __future__ import annotations from .configuration_qwen2_hybrid import Qwen2HybridConfig from typing import Dict, List, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint from transformers.cache_utils import Cache from transformers.generation.utils import GenerationMixin from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.utils import add_start_docstrings, logging from transformers.models.qwen2.configuration_qwen2 import Qwen2Config from transformers.models.qwen2.modeling_qwen2 import ( Qwen2Attention, Qwen2MLP, Qwen2PreTrainedModel, Qwen2RMSNorm, Qwen2RotaryEmbedding, apply_rotary_pos_emb, repeat_kv, ) import transformers.models.qwen2.modeling_qwen2 as qwen2_modeling logger = logging.get_logger(__name__) _GQA_LAYERS = set(range(0, 7)) _SHARED_LAYER = 7 _SOFT_MID_LAYERS = set(range(8, 23)) _SOFT_DEEP_LAYERS = set(range(23, 28)) _GQA_SLIDING_WINDOW = 32768 # 前几层的SW为什么这么大 # _SOFT_SLIDING_WINDOW = 4096 _SOFT_SLIDING_WINDOW = 8192 _SHARED_RANK = 320 # hidenstage是1536 _SOFT_RANK_MID = 192 _SOFT_RANK_DEEP = 128 def _layer_role(layer_idx: int) -> str: if layer_idx in _GQA_LAYERS: return "gqa" if layer_idx == _SHARED_LAYER: return "shared_mla" return "soft_mla" def _mla_rank(layer_idx: int) -> int: if layer_idx == _SHARED_LAYER: return _SHARED_RANK if layer_idx in _SOFT_MID_LAYERS: return _SOFT_RANK_MID return _SOFT_RANK_DEEP def _mla_sliding_window(layer_idx: int) -> Optional[int]: return None if layer_idx == _SHARED_LAYER else _SOFT_SLIDING_WINDOW def _mla_zone(layer_idx: int) -> str: if layer_idx in _GQA_LAYERS: return "gqa" if layer_idx == _SHARED_LAYER: return "shared" if layer_idx in _SOFT_MID_LAYERS: return "mid" return "deep" # HybridCache:支持"Attention Sinks"的双模缓存 # 这部分的两个关键:混合缓存管理(HybridCache) 与 跨层特征共享(SharedLatentGate) # 
# ---------------------------------------------------------------------------
# HybridCache: a dual-mode KV cache with "attention sink" support.
#
# * GQA layers store full key/value tensors in ``_gqa_k`` / ``_gqa_v`` and
#   trim them to a plain sliding window.
# * MLA layers store only their compressed latent ``c_kv`` in ``_latent`` and
#   (optionally) trim it to "first ``sink_size`` tokens + most recent tokens".
# The latent stored for ``_SHARED_LAYER`` is also read back by
# ``SharedLatentGate`` (``get_shared_latent``) for cross-layer sharing.
# The hybrid model instantiates this cache itself when ``use_cache`` is set.
# ---------------------------------------------------------------------------
class HybridCache(Cache):
    """Per-layer cache holding full K/V for GQA layers and latents for MLA layers."""

    def __init__(self, config: Qwen2Config):
        # Some ``transformers`` versions expect a ``layers`` argument here;
        # others take none and raise TypeError, so fall back to the bare call.
        try:
            super().__init__(layers=config.num_hidden_layers)
        except TypeError:
            super().__init__()
        self.config = config
        n = config.num_hidden_layers
        # 4-D tensors, sequence on dim 2 (presumably [batch, kv_heads, seq, head_dim]).
        self._gqa_k: List[Optional[torch.Tensor]] = [None] * n
        self._gqa_v: List[Optional[torch.Tensor]] = [None] * n
        # 3-D latents, sequence on dim 1 ([batch, seq, kv_lora_rank]).
        self._latent: List[Optional[torch.Tensor]] = [None] * n
        # Total tokens processed so far; counted by layer 0 in ``update_gqa``.
        # Used to derive ``cache_position`` / position ids for new tokens.
        self._seen_tokens: int = 0

    def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
        """Standard ``Cache.update`` entry point; routes to ``update_gqa``."""
        return self.update_gqa(key_states, value_states, layer_idx)

    def get_seq_length(self, layer_idx: int = 0) -> int:
        """Return the number of tokens processed so far (``layer_idx`` is ignored)."""
        return self._seen_tokens

    def get_max_cache_shape(self) -> Optional[int]:
        """Return ``None``: this cache has no fixed maximum length."""
        return None

    def update_gqa(self, key, value, layer_idx, sliding_window=_GQA_SLIDING_WINDOW):
        """Append new K/V for a GQA layer and trim to the last ``sliding_window`` tokens.

        Args:
            key, value: new states, concatenated along dim 2 (the sequence axis).
            layer_idx: index of the calling layer. Layer 0 also advances
                ``_seen_tokens`` by the number of new tokens.
            sliding_window: maximum number of cached positions to keep.

        Returns:
            ``(keys, values)`` including history, after trimming.
        """
        if self._gqa_k[layer_idx] is None:
            self._gqa_k[layer_idx] = key
            self._gqa_v[layer_idx] = value
        else:
            self._gqa_k[layer_idx] = torch.cat([self._gqa_k[layer_idx], key], dim=2)
            self._gqa_v[layer_idx] = torch.cat([self._gqa_v[layer_idx], value], dim=2)

        if self._gqa_k[layer_idx].shape[2] > sliding_window:
            self._gqa_k[layer_idx] = self._gqa_k[layer_idx][:, :, -sliding_window:, :]
            self._gqa_v[layer_idx] = self._gqa_v[layer_idx][:, :, -sliding_window:, :]

        # Prefill passes many tokens at once and decoding passes one (or a few).
        # In either case, layer 0 counts the new tokens exactly once per forward.
        if layer_idx == 0:
            self._seen_tokens += key.shape[2]
        return self._gqa_k[layer_idx], self._gqa_v[layer_idx]

    def update_latent(self, c_kv, layer_idx, sliding_window=None, sink_size=64):
        """Append an MLA latent and, if needed, trim it StreamingLLM-style.

        When the stored length exceeds ``sliding_window``, the first
        ``sink_size`` tokens (the "attention sinks") are kept together with the
        most recent ``sliding_window - sink_size`` tokens. Because the head of
        the stored tensor is always the original sink, repeated trimming keeps
        the same sink tokens.

        Args:
            c_kv: new latent, concatenated along dim 1 (the sequence axis).
            layer_idx: index of the calling layer.
            sliding_window: maximum stored length, or ``None`` for unbounded.
            sink_size: number of leading tokens always retained. It is clamped
                to ``sliding_window`` so a small window never duplicates tokens.

        Returns:
            The updated (possibly trimmed) latent for this layer.
        """
        if self._latent[layer_idx] is None:
            self._latent[layer_idx] = c_kv
        else:
            self._latent[layer_idx] = torch.cat([self._latent[layer_idx], c_kv], dim=1)

        latent = self._latent[layer_idx]
        if sliding_window is not None and latent.shape[1] > sliding_window:
            sink = min(sink_size, sliding_window)
            n_recent = sliding_window - sink
            head = latent[:, :sink, :]
            # Guard: ``latent[:, -0:, :]`` would return everything, so only
            # slice the tail when there is room for recent tokens.
            if n_recent > 0:
                self._latent[layer_idx] = torch.cat([head, latent[:, -n_recent:, :]], dim=1)
            else:
                self._latent[layer_idx] = head
        return self._latent[layer_idx]

    def get_shared_latent(self) -> Optional[torch.Tensor]:
        """Return the stored latent of ``_SHARED_LAYER`` (``None`` before it runs)."""
        return self._latent[_SHARED_LAYER]

    def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
        """Reorder every cached tensor along the batch dimension (beam search).

        This overrides the base implementation, which iterates storage that
        this subclass never populates.
        """
        for store in (self._gqa_k, self._gqa_v, self._latent):
            for i, tensor in enumerate(store):
                if tensor is not None:
                    store[i] = tensor.index_select(0, beam_idx.to(tensor.device))

    def to(self, device):
        """Move all cached tensors to ``device`` in place and return ``self``.

        The cached tensors are plain list entries rather than module buffers,
        so ``model.to(...)`` does not move them. This must be done by hand.
        """
        for i in range(len(self._gqa_k)):
            if self._gqa_k[i] is not None:
                self._gqa_k[i] = self._gqa_k[i].to(device)
                self._gqa_v[i] = self._gqa_v[i].to(device)
            if self._latent[i] is not None:
                self._latent[i] = self._latent[i].to(device)
        return self


class _GQASlotAdapter:
    """Thin ``Cache``-like view of a ``HybridCache`` for the stock GQA attention.

    The stock Qwen2 attention (layers 0-6) calls ``update`` /
    ``get_seq_length`` on whatever cache it receives. This adapter forwards
    those calls to ``HybridCache.update_gqa`` with a fixed sliding window.
    """

    def __init__(self, cache: HybridCache, sliding_window: int = _GQA_SLIDING_WINDOW):
        self._cache = cache
        self._window = sliding_window

    def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
        """Append K/V for ``layer_idx`` and return the windowed history."""
        return self._cache.update_gqa(key_states, value_states, layer_idx, self._window)

    def get_seq_length(self, layer_idx: int = 0) -> int:
        """Return the wrapped cache's processed-token count."""
        return self._cache.get_seq_length(layer_idx)

    def get_max_cache_shape(self) -> Optional[int]:
        """Return ``None``: no fixed maximum length."""
        return None
class SharedLatentGate(nn.Module):
    """Gated residual projection that injects the shared layer's latent into a layer.

    The latent ``c_kv`` produced by ``_SHARED_LAYER`` (width ``_SHARED_RANK``)
    goes through three steps:
      1. it is projected back to ``hidden_size``;
      2. it is RMS-normalised;
      3. it is added to the incoming hidden states, scaled per channel by
         ``sigmoid(gate) * warmup_alpha``.
    ``gate`` starts at -4 and ``warmup_alpha`` at 0, so the module is an
    identity mapping at initialisation.
    """

    def __init__(self, config: Qwen2Config):
        super().__init__()
        H = config.hidden_size
        self.cross_proj = nn.Linear(_SHARED_RANK, H, bias=False)
        # One gate logit per hidden channel.
        self.gate = nn.Parameter(torch.full((H,), -4.0))
        # Global scalar "master valve" on the injected signal.
        self.warmup_alpha = nn.Parameter(torch.tensor(0.0))
        self.norm = Qwen2RMSNorm(H, eps=config.rms_norm_eps)

    def forward(self, hidden_states, cache=None, explicit_shared=None):
        """Add the gated shared-latent projection to ``hidden_states``.

        Args:
            hidden_states: ``[batch, seq, hidden]`` input of the layer.
            cache: optional ``HybridCache``. When it already holds the shared
                layer's latent (cached inference), that latent is used.
            explicit_shared: latent passed along layer-to-layer. It is used when
                no cached latent is available, e.g. training without a cache.

        Returns:
            ``hidden_states`` unchanged if no shared latent is available,
            otherwise ``hidden_states + gate * norm(cross_proj(latent))``.
        """
        cached = cache.get_shared_latent() if cache is not None else None
        if cached is not None:
            shared = cached
        elif explicit_shared is not None:
            shared = explicit_shared
        else:
            return hidden_states

        _, T, _ = hidden_states.shape
        # The cached latent covers the whole history. Keep only the last T
        # positions so that each token receives its own shared-layer latent.
        if shared.shape[1] != T:
            shared = shared[:, -T:, :]

        proj = self.norm(self.cross_proj(shared))
        gate_weight = torch.sigmoid(self.gate) * self.warmup_alpha
        return hidden_states + gate_weight.unsqueeze(0).unsqueeze(0) * proj


class Qwen2MLASoftAttention(nn.Module):
    """Low-rank latent attention used by layers from ``_SHARED_LAYER`` on.

    K/V are not projected directly from the hidden states. The hidden states
    are first compressed to a ``kv_lora_rank`` latent ``c_kv`` (this is what
    is cached), and K/V are re-expanded from the full latent history on every
    call. RoPE is therefore applied to keys at attention time, using position
    ids rebuilt for exactly the cached tokens (including any attention-sink
    tokens that the cache retained).
    """

    # Number of leading "attention sink" latent tokens kept when the cache is
    # trimmed. It is passed explicitly to ``update_latent`` so that cache
    # trimming and the position/mask reconstruction in ``forward`` agree.
    sink_size: int = 64

    def __init__(self, config, layer_idx, kv_lora_rank, sliding_window):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.kv_lora_rank = kv_lora_rank
        self.sliding_window = sliding_window

        H = config.hidden_size
        nh = config.num_attention_heads
        nkv = config.num_key_value_heads
        self.head_dim = getattr(config, "head_dim", H // nh)
        self.num_heads = nh
        self.num_kv_heads = nkv
        # Query heads per KV head; ``repeat_kv`` uses it to expand K/V.
        self.num_kv_groups = nh // nkv
        # Standard 1/sqrt(head_dim) softmax temperature.
        self.scaling = self.head_dim ** -0.5

        self.q_proj = nn.Linear(H, nh * self.head_dim, bias=True)
        # Joint K/V compression: hidden -> latent (the cached quantity).
        self.kv_down_proj = nn.Linear(H, kv_lora_rank, bias=False)
        # Separate expansions latent -> K and latent -> V.
        self.k_up_proj = nn.Linear(kv_lora_rank, nkv * self.head_dim, bias=True)
        self.v_up_proj = nn.Linear(kv_lora_rank, nkv * self.head_dim, bias=True)
        self.o_proj = nn.Linear(nh * self.head_dim, H, bias=False)
        # Per-head RMSNorm on the re-expanded K and V.
        self.k_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.v_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        # Dedicated rotary module: keys need cos/sin for the cached positions,
        # which differ from the query positions supplied by the model.
        self.rotary_emb = Qwen2RotaryEmbedding(config=config)
        # Scalar output gate, starting at 0: a fresh layer initially adds
        # nothing to the residual stream.
        self.output_alpha = nn.Parameter(torch.tensor(0.0))

    def _key_position_ids(self, full_position_ids, cache_position, T_kv, device):
        """Return RoPE position ids ``[batch or 1, T_kv]`` for the attended keys.

        * With ``full_position_ids`` (derived from a 2-D padding mask over the
          whole sequence), the real positions of the cached tokens are used.
          When the cache has been sink-trimmed, these are the first
          ``sink_size`` positions followed by the most recent ones.
        * Otherwise, contiguous positions ending at the last query position
          (taken from ``cache_position``) are used, or ``0..T_kv-1`` when
          ``cache_position`` is also absent.
        """
        S = self.sink_size
        if full_position_ids is not None:
            total_seq_len = full_position_ids.shape[1]
            # Not trimmed (no window, still within it, or prefill): take the tail.
            if self.sliding_window is None or total_seq_len <= self.sliding_window or T_kv == total_seq_len:
                return full_position_ids[:, -T_kv:].contiguous()
            sink_pos = full_position_ids[:, :S]
            recent_pos = full_position_ids[:, -(T_kv - S):]
            return torch.cat([sink_pos, recent_pos], dim=1).contiguous()
        if cache_position is not None:
            last_abs_pos = cache_position[-1]
            offsets = torch.arange(T_kv, device=device, dtype=torch.long)
            return (offsets + (last_abs_pos + 1 - T_kv)).unsqueeze(0)
        return torch.arange(T_kv, device=device, dtype=torch.long).unsqueeze(0)

    def _align_attention_mask(self, attention_mask, q_len, kv_len, device):
        """Return ``(attn_mask, is_causal)`` for SDPA, aligned to the attended keys.

        * A 4-D mask built for the full sequence is cut down to the ``kv_len``
          attended columns. When the cache was sink-trimmed, this is the first
          ``sink_size`` columns plus the most recent ones.
        * Without a mask and with ``q_len > 1``:
            - if ``q_len == kv_len``, SDPA's built-in causal flag is used;
            - otherwise an explicit bottom-right causal mask is built. SDPA's
              ``is_causal`` is top-left aligned for non-square masks, which
              would hide the cached prefix from the new queries.
        """
        if attention_mask is None:
            if q_len > 1 and kv_len != q_len:
                # Query i (the i-th new token) may see keys 0 .. kv_len - q_len + i.
                causal = torch.ones(q_len, kv_len, dtype=torch.bool, device=device)
                return causal.tril(diagonal=kv_len - q_len), False
            return None, q_len > 1

        total_mask_len = attention_mask.shape[-1]
        if total_mask_len > kv_len:
            S = self.sink_size
            if self.sliding_window is None or total_mask_len <= self.sliding_window:
                attention_mask = attention_mask[..., :, -kv_len:].contiguous()
            else:
                sink_mask = attention_mask[..., :, :S]
                recent_mask = attention_mask[..., :, -(kv_len - S):]
                attention_mask = torch.cat([sink_mask, recent_mask], dim=-1).contiguous()
        return attention_mask, False

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[HybridCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        full_position_ids: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run latent attention for the new tokens.

        Args:
            hidden_states: ``[batch, T, hidden]``, already normalised by the
                decoder layer's ``input_layernorm``.
            position_embeddings: ``(cos, sin)`` for the query positions,
                computed once per forward by the model's rotary module.
            attention_mask: ``None`` or a 4-D additive mask whose last
                dimension spans the whole sequence (history + new tokens).
            past_key_values: ``HybridCache`` holding this layer's latent, or
                ``None`` (no caching, e.g. training).
            cache_position: absolute positions of the new tokens.
            full_position_ids: padding-aware positions for the whole sequence,
                or ``None``.

        Returns:
            ``(output * output_alpha, c_kv)``. ``c_kv`` is this step's latent
            ``[batch, T, kv_lora_rank]``.
        """
        B, T, _ = hidden_states.shape
        cos, sin = position_embeddings

        q = self.q_proj(hidden_states).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        q, _ = apply_rotary_pos_emb(q, q, cos, sin)

        c_kv = self.kv_down_proj(hidden_states)

        # Attend over (cached latent history + this step's latent). The cache
        # is updated (and possibly trimmed) afterwards, so this step still sees
        # the untrimmed history plus all new tokens.
        if past_key_values is not None:
            past_latent = past_key_values._latent[self.layer_idx]
            full_c_kv = c_kv if past_latent is None else torch.cat([past_latent, c_kv], dim=1)
            past_key_values.update_latent(
                c_kv, self.layer_idx, sliding_window=self.sliding_window, sink_size=self.sink_size
            )
        else:
            full_c_kv = c_kv

        T_kv = full_c_kv.shape[1]
        k = self.k_up_proj(full_c_kv).view(B, T_kv, self.num_kv_heads, self.head_dim)
        v = self.v_up_proj(full_c_kv).view(B, T_kv, self.num_kv_heads, self.head_dim)
        # RMSNorm acts on the last (head_dim) axis, so it is applied before
        # moving heads to dim 1.
        k = self.k_norm(k).transpose(1, 2)
        v = self.v_norm(v).transpose(1, 2)

        key_pos_ids = self._key_position_ids(full_position_ids, cache_position, T_kv, hidden_states.device)
        cos_k, sin_k = self.rotary_emb(k, key_pos_ids)
        k, _ = apply_rotary_pos_emb(k, k, cos_k, sin_k)

        k = repeat_kv(k, self.num_kv_groups)
        v = repeat_kv(v, self.num_kv_groups)
        # Make the transposed/expanded views dense before SDPA.
        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()

        attn_mask, is_causal = self._align_attention_mask(attention_mask, T, k.shape[2], k.device)

        out = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=is_causal, scale=self.scaling
        )
        out = out.transpose(1, 2).contiguous().view(B, T, -1)
        out = self.o_proj(out) * self.output_alpha
        return out, c_kv


class Qwen2HybridDecoderLayer(nn.Module):
    """Pre-norm decoder layer whose attention is stock Qwen2 GQA or latent MLA.

    * GQA layers use the ``transformers`` Qwen2 attention class.
    * MLA layers use :class:`Qwen2MLASoftAttention`.
    * Soft-MLA layers additionally start with a :class:`SharedLatentGate`.
    """

    def __init__(self, config: Qwen2Config, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.layer_role = _layer_role(layer_idx)

        if self.layer_role == "gqa":
            attn_impl = getattr(config, "_attn_implementation", "sdpa")
            # Older transformers expose a per-implementation class table;
            # fall back to the generic class when it is absent.
            attn_class = getattr(qwen2_modeling, "QWEN2_ATTENTION_CLASSES", {}).get(attn_impl, Qwen2Attention)
            self.self_attn = attn_class(config=config, layer_idx=layer_idx)
        else:
            self.self_attn = Qwen2MLASoftAttention(
                config=config,
                layer_idx=layer_idx,
                kv_lora_rank=_mla_rank(layer_idx),
                sliding_window=_mla_sliding_window(layer_idx),
            )

        self.shared_gate = SharedLatentGate(config) if self.layer_role == "soft_mla" else None
        self.mlp = Qwen2MLP(config)
        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        use_cache=False,
        cache_position=None,
        position_embeddings=None,
        output_attentions=False,
        shared_latent=None,
        full_position_ids=None,
        **kwargs,
    ):
        """Run ``x + Attn(Norm(x))`` then ``x + MLP(Norm(x))``.

        Args:
            past_key_values: a ``_GQASlotAdapter`` for GQA layers, the
                ``HybridCache`` for MLA layers, or ``None``.
            shared_latent: the shared layer's latent as passed along the
                layer loop (used by ``SharedLatentGate`` when no cache is set).
            full_position_ids: padding-aware positions forwarded to MLA.

        Returns:
            ``(hidden_states, shared_latent)``. The shared layer replaces
            ``shared_latent`` with its own ``c_kv``; other layers pass it through.
        """
        if self.shared_gate is not None:
            real_cache = past_key_values._cache if isinstance(past_key_values, _GQASlotAdapter) else past_key_values
            hidden_states = self.shared_gate(hidden_states, cache=real_cache, explicit_shared=shared_latent)

        residual = hidden_states
        normed_input = self.input_layernorm(hidden_states)

        if self.layer_role == "gqa":
            # The stock attention computes its own K/V positions from the query
            # positions (position_embeddings), because its cache stores
            # already-rotated keys.
            # NOTE(review): stock attention slices the mask's leading key
            # columns. Once the GQA cache exceeds _GQA_SLIDING_WINDOW, the mask
            # may no longer line up with the cached keys; confirm against the
            # installed transformers version.
            attn_outputs = self.self_attn(
                hidden_states=normed_input,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )
            # Depending on the transformers version this is (out, weights) or
            # (out, weights, cache); only the attention output is used here.
            attn_out = attn_outputs[0]
        else:
            # MLA re-rotates keys from latents, so it needs positions for every
            # cached token, not just the new ones.
            attn_out, c_kv = self.self_attn(
                hidden_states=normed_input,
                position_embeddings=position_embeddings,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                cache_position=cache_position,
                full_position_ids=full_position_ids,
            )
            if self.layer_role == "shared_mla":
                shared_latent = c_kv

        hidden_states = residual + attn_out

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states, shared_latent


@add_start_docstrings("Qwen2.5-Coder 非对称混合架构主干,v9。")
class Qwen2HybridModel(Qwen2PreTrainedModel):
    """Decoder stack mixing GQA layers (0-6) with latent MLA layers (7+)."""

    # Declares the config class used by ``from_pretrained``.
    config_class = Qwen2HybridConfig

    def __init__(self, config: Qwen2HybridConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([Qwen2HybridDecoderLayer(config, i) for i in range(config.num_hidden_layers)])
        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen2RotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def _build_causal_mask(self, attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions):
        """Return the 4-D additive mask for the layers, or ``None`` to use implicit causality.

        The mask may only be dropped (letting SDPA apply ``is_causal``) when all
        of the following hold:
          * the attention implementation is SDPA and ``output_attentions`` is off;
          * there is no padding;
          * the query block is square against the keys, i.e. a single new token
            or an empty cache.
        For ``T > 1`` tokens on top of a non-empty cache, SDPA's top-left
        aligned causal mask would be wrong, so an explicit mask is built.
        """
        B, T, _ = inputs_embeds.shape
        # A 2-D mask that is all ones means "no padding".
        is_all_ones = attention_mask is None or bool(attention_mask.min() == 1)
        has_past = past_key_values is not None and past_key_values.get_seq_length() > 0
        if (
            getattr(self.config, "_attn_implementation", "sdpa") == "sdpa"
            and not output_attentions
            and is_all_ones
            and (T == 1 or not has_past)
        ):
            return None
        past_kv_len = int(cache_position[0].item()) if T > 0 else 0
        return _prepare_4d_causal_attention_mask(
            attention_mask, (B, T), inputs_embeds, past_kv_len, sliding_window=None
        )

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        cache_position=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        **kwargs,
    ):
        """Embed the inputs, run all decoder layers, and apply the final norm.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given. When
        ``use_cache`` is truthy and ``past_key_values`` is not a
        ``HybridCache``, a fresh ``HybridCache`` is created.
        """
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("必须且只能指定 input_ids 或 inputs_embeds 之一")
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        B, T, _ = inputs_embeds.shape

        if use_cache and not isinstance(past_key_values, HybridCache):
            past_key_values = HybridCache(config=self.config)

        # Absolute positions of the new tokens: they continue after the
        # tokens already counted by the cache.
        if cache_position is None:
            past_seen = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(past_seen, past_seen + T, device=inputs_embeds.device)
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # Padding-aware positions over the whole sequence (history + new),
        # derived from a 2-D mask: padded slots are filled with 1. MLA layers
        # use these to rotate re-expanded cached keys correctly under left padding.
        if attention_mask is not None and attention_mask.dim() == 2:
            full_position_ids = attention_mask.long().cumsum(-1) - 1
            full_position_ids = full_position_ids.masked_fill(attention_mask == 0, 1)
        else:
            full_position_ids = None

        causal_4d = self._build_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # The stock GQA attention expects a cache exposing ``update``, so GQA
        # layers get this adapter; MLA layers use the HybridCache directly.
        gqa_adapter = _GQASlotAdapter(past_key_values) if past_key_values is not None else None
        all_hidden_states = () if output_hidden_states else None
        shared_latent = None

        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            effective_cache = gqa_adapter if layer.layer_role == "gqa" else past_key_values

            if self.gradient_checkpointing and self.training:
                if cache_position is not None:
                    assert cache_position.device == inputs_embeds.device
                # Positional arguments follow the layer's forward signature.
                # The cache is disabled (None / False) under checkpointing.
                outputs = torch.utils.checkpoint.checkpoint(
                    layer,
                    hidden_states,
                    causal_4d,
                    position_ids,
                    None,
                    False,
                    cache_position,
                    position_embeddings,
                    output_attentions,
                    shared_latent,
                    full_position_ids,
                    use_reentrant=False,
                )
            else:
                outputs = layer(
                    hidden_states,
                    attention_mask=causal_4d,
                    position_ids=position_ids,
                    past_key_values=effective_cache,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    output_attentions=output_attentions,
                    shared_latent=shared_latent,
                    full_position_ids=full_position_ids,
                )
            hidden_states, shared_latent = outputs[0], outputs[1]

        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, past_key_values if use_cache else None, all_hidden_states, None]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=None,
        )


class Qwen2HybridForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
    """Causal language-model head on top of :class:`Qwen2HybridModel`."""

    _tied_weights_keys = ["lm_head.weight"]
    # Declares the config class used by ``from_pretrained``.
    config_class = Qwen2HybridConfig

    def __init__(self, config: Qwen2HybridConfig):
        super().__init__(config)
        self.model = Qwen2HybridModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def _init_weights(self, module: nn.Module):
        """Standard Qwen2 init, plus zeroed gates so new modules start as no-ops."""
        super()._init_weights(module)
        if isinstance(module, Qwen2MLASoftAttention):
            nn.init.zeros_(module.output_alpha)
        elif isinstance(module, SharedLatentGate):
            nn.init.zeros_(module.warmup_alpha)
            nn.init.constant_(module.gate, -4.0)

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        cache_position=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        **kwargs,
    ) -> Union[CausalLMOutputWithPast, Tuple]:
        """Compute float32 logits and, if ``labels`` are given, the shifted LM loss.

        Label positions equal to -100 are ignored by the loss.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        hidden_states = outputs.last_hidden_state
        logits = self.lm_head(hidden_states).float()

        loss = None
        if labels is not None:
            # Token t predicts token t+1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1), ignore_index=-100
            )

        if not return_dict:
            out = (logits,)
            if use_cache:
                out = out + (outputs.past_key_values,)
            if output_hidden_states:
                out = out + (outputs.hidden_states,)
            return ((loss,) + out) if loss is not None else out

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        **kwargs,
    ) -> dict:
        """Assemble the model kwargs for one ``generate`` step.

        * First step (empty cache): the whole prompt is fed, as
          ``inputs_embeds`` when those were supplied.
        * Later steps: only the newest token id is fed, since the earlier
          tokens already live in the cache.
        Position ids come from the 2-D ``attention_mask`` (left-padding aware)
        unless the caller supplied ``position_ids``. On decode steps they are
        cut to the new tokens.
        """
        past_len = past_key_values.get_seq_length() if past_key_values is not None else 0
        use_embeds = inputs_embeds is not None and past_len == 0

        if past_len > 0:
            # Trim ids on every decode step, including when generation began
            # from ``inputs_embeds``. Otherwise all generated ids are re-fed.
            input_ids = input_ids[:, -1:]

        # Count new tokens from what is actually fed. On an embeds-first step,
        # ``input_ids`` may be empty.
        num_new = inputs_embeds.shape[1] if use_embeds else input_ids.shape[1]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids = position_ids.masked_fill(attention_mask == 0, 1)
        if position_ids is not None and past_len > 0:
            position_ids = position_ids[:, -num_new:]

        if cache_position is None:
            device = inputs_embeds.device if use_embeds else input_ids.device
            cache_position = torch.arange(past_len, past_len + num_new, device=device)

        model_inputs = {"inputs_embeds": inputs_embeds} if use_embeds else {"input_ids": input_ids}
        model_inputs.update({
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache", True),
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "cache_position": cache_position,
        })
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Legacy beam-search hook: reorder all cached tensors along the batch dim."""
        for store in (past_key_values._gqa_k, past_key_values._gqa_v, past_key_values._latent):
            for i, tensor in enumerate(store):
                if tensor is not None:
                    store[i] = tensor.index_select(0, beam_idx)
        return past_key_values


def _svd_project_kv(k_weight, v_weight, kv_rank, k_bias=None, v_bias=None):
    """Factor stacked K/V projection weights into a rank-``kv_rank`` down/up pair.

    With ``M = [k_weight; v_weight] = U diag(S) Vh`` (computed in float32), the
    leading ``r = min(kv_rank, len(S))`` components are kept and sqrt(S) is
    split evenly between the factors:
      * ``down = diag(sqrt S_r) Vh_r``
      * ``up   = U_r diag(sqrt S_r)``
    Hence ``up @ down`` is the best rank-r approximation of ``M``. (Previously
    sqrt(S) was applied only to ``up``, so the product was ``U sqrt(S) Vh``.)

    Args:
        k_weight, v_weight: ``[out, hidden]`` weights of the original K/V projections.
        kv_rank: target latent rank.
        k_bias, v_bias: optional original biases, moved onto the up projections.

    Returns:
        ``(down_w [r, hidden], k_up_w [k_out, r], v_up_w [v_out, r], k_up_bias, v_up_bias)``,
        cast back to ``k_weight.dtype``.
    """
    nkv_d = k_weight.shape[0]
    orig_dtype = k_weight.dtype
    M = torch.cat([k_weight, v_weight], dim=0).float()
    U, S, Vh = torch.linalg.svd(M, full_matrices=False)
    r = min(kv_rank, S.shape[0])
    s_sqrt = S[:r].sqrt()

    down_w = (s_sqrt.unsqueeze(1) * Vh[:r, :]).to(orig_dtype)
    k_up_w = (U[:nkv_d, :r] * s_sqrt.unsqueeze(0)).to(orig_dtype)
    v_up_w = (U[nkv_d:, :r] * s_sqrt.unsqueeze(0)).to(orig_dtype)
    k_up_bias = k_bias.to(orig_dtype) if k_bias is not None else None
    v_up_bias = v_bias.to(orig_dtype) if v_bias is not None else None
    return down_w, k_up_w, v_up_w, k_up_bias, v_up_bias


def migrate_weights_from_qwen2(hybrid_model, original_state_dict, svd_verbose=True):
    """Load a stock Qwen2 state dict into the hybrid model.

    Mapping rules:
      * Non-layer keys and all keys of GQA layers are copied where the name exists.
      * For MLA layers, ``q_proj``, ``o_proj`` (shape-checked), MLP and layernorm
        weights are copied as-is.
      * ``k_proj`` / ``v_proj`` are replaced by an SVD factorisation into
        ``kv_down_proj`` / ``k_up_proj`` / ``v_up_proj`` at the layer's rank.
      * ``output_alpha`` and ``warmup_alpha`` are set to 0 and ``gate`` to -4.

    Loading uses ``strict=False``.

    Returns:
        List of original keys (with optional annotations) that were not mapped.
    """
    hybrid_sd = hybrid_model.state_dict()
    new_sd, unmapped = {}, []

    # Pass 1: collect K/V weights and biases of every MLA layer for SVD.
    layer_kv = {}
    for orig_key, orig_val in original_state_dict.items():
        if not orig_key.startswith("model.layers."):
            continue
        parts = orig_key.split(".")
        layer_idx = int(parts[2])
        suffix = ".".join(parts[3:])
        if _layer_role(layer_idx) == "gqa":
            continue
        if suffix == "self_attn.k_proj.weight":
            layer_kv.setdefault(layer_idx, {})["k_w"] = orig_val
        elif suffix == "self_attn.v_proj.weight":
            layer_kv.setdefault(layer_idx, {})["v_w"] = orig_val
        elif suffix == "self_attn.k_proj.bias":
            layer_kv.setdefault(layer_idx, {})["k_b"] = orig_val
        elif suffix == "self_attn.v_proj.bias":
            layer_kv.setdefault(layer_idx, {})["v_b"] = orig_val

    # Pass 2: direct copies.
    for orig_key, orig_val in original_state_dict.items():
        if not orig_key.startswith("model.layers."):
            if orig_key in hybrid_sd:
                new_sd[orig_key] = orig_val
            else:
                unmapped.append(orig_key)
            continue

        parts = orig_key.split(".")
        layer_idx = int(parts[2])
        suffix = ".".join(parts[3:])
        role = _layer_role(layer_idx)
        tgt = f"model.layers.{layer_idx}.{suffix}"

        if role == "gqa":
            if tgt in hybrid_sd:
                new_sd[tgt] = orig_val
            else:
                unmapped.append(orig_key)
            continue

        if suffix in ("self_attn.q_proj.weight", "self_attn.q_proj.bias"):
            if tgt in hybrid_sd:
                new_sd[tgt] = orig_val
            else:
                unmapped.append(orig_key)
        elif suffix in (
            "self_attn.k_proj.weight",
            "self_attn.v_proj.weight",
            "self_attn.k_proj.bias",
            "self_attn.v_proj.bias",
        ):
            pass  # handled by the SVD pass below
        elif suffix == "self_attn.o_proj.weight":
            if tgt in hybrid_sd and hybrid_sd[tgt].shape == orig_val.shape:
                new_sd[tgt] = orig_val
            else:
                unmapped.append(f"{orig_key} [shape mismatch or missing]")
        elif "mlp." in suffix or "layernorm" in suffix:
            if tgt in hybrid_sd:
                new_sd[tgt] = orig_val
            else:
                unmapped.append(orig_key)
        else:
            # Report unrecognised MLA-layer keys instead of dropping them silently.
            unmapped.append(orig_key)

    # Pass 3: SVD warm start for the MLA K/V projections.
    svd_done, svd_errors = 0, []
    for layer_idx in sorted(layer_kv.keys()):
        kv = layer_kv[layer_idx]
        k_w, v_w = kv.get("k_w"), kv.get("v_w")
        if k_w is None or v_w is None:
            svd_errors.append(f"Layer {layer_idx}: 缺少 k_w 或 v_w")
            continue
        rank = _mla_rank(layer_idx)
        zone = _mla_zone(layer_idx)
        k_b, v_b = kv.get("k_b"), kv.get("v_b")
        if svd_verbose:
            bias_info = "w/ bias" if k_b is not None else "no bias"
            print(f" [SVD] Layer {layer_idx:2d} [{zone:6s}] k{list(k_w.shape)} + v{list(v_w.shape)} → rank={rank:3d} ({bias_info})")
        try:
            down_w, k_up_w, v_up_w, k_up_b, v_up_b = _svd_project_kv(k_w, v_w, rank, k_bias=k_b, v_bias=v_b)
        except Exception as exc:
            svd_errors.append(f"Layer {layer_idx}: SVD failed — {exc}")
            continue

        pfx = f"model.layers.{layer_idx}.self_attn"
        for key, weight in [
            (f"{pfx}.kv_down_proj.weight", down_w),
            (f"{pfx}.k_up_proj.weight", k_up_w),
            (f"{pfx}.v_up_proj.weight", v_up_w),
        ]:
            if key in hybrid_sd and hybrid_sd[key].shape == weight.shape:
                new_sd[key] = weight
            else:
                svd_errors.append(f"{key}: shape mismatch")
        for key, bias_val in [(f"{pfx}.k_up_proj.bias", k_up_b), (f"{pfx}.v_up_proj.bias", v_up_b)]:
            if bias_val is not None and key in hybrid_sd:
                if hybrid_sd[key].shape == bias_val.shape:
                    new_sd[key] = bias_val
        svd_done += 1

    # Pass 4: initial values for the hybrid-only gate parameters.
    custom_written = 0
    for key in hybrid_sd:
        if key.endswith(".self_attn.output_alpha"):
            new_sd[key] = torch.tensor(0.0)
            custom_written += 1
        elif key.endswith(".shared_gate.warmup_alpha"):
            new_sd[key] = torch.tensor(0.0)
            custom_written += 1
        elif key.endswith(".shared_gate.gate"):
            new_sd[key] = torch.full(hybrid_sd[key].shape, -4.0)
            custom_written += 1

    missing, unexpected = hybrid_model.load_state_dict(new_sd, strict=False)

    if svd_verbose:
        sep = "=" * 65
        print(f"\n{sep}\n[migrate_weights_v9] Qwen2 → Hybrid v9 迁移完成\n{sep}")
        print(f" Rank: shared(L7)={_SHARED_RANK} | mid(L8-22)={_SOFT_RANK_MID} | deep(L23-27)={_SOFT_RANK_DEEP}")
        print(f" SVD 热启动 : {svd_done} 层\n 自定义参数写入 : {custom_written} 个\n 总写入 keys : {len(new_sd)}")
        print(f" 缺失(新增模块) : {len(missing):3d}\n 意外(多余) : {len(unexpected):3d}\n 未映射原始 keys : {len(unmapped):3d}")
        if svd_errors:
            for e in svd_errors:
                print(f" ⚠ {e}")
        print(f"{sep}\n")
    return unmapped


def get_alpha_param_groups(model, base_lr, alpha_lr_scale=10.0):
    """Split trainable parameters into a base group and a faster "alpha gate" group.

    ``output_alpha`` / ``warmup_alpha`` scalars get ``base_lr * alpha_lr_scale``.
    All other trainable parameters get ``base_lr``.
    """
    alpha_params, base_params, alpha_names = [], [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if name.endswith(".self_attn.output_alpha") or name.endswith(".shared_gate.warmup_alpha"):
            alpha_params.append(param)
            alpha_names.append(name)
        else:
            base_params.append(param)
    print(f"[get_alpha_param_groups]\n Base params : {len(base_params):4d} lr={base_lr:.2e}\n Alpha params : {len(alpha_params):4d} lr={base_lr * alpha_lr_scale:.2e}")
    return [
        {"params": base_params, "lr": base_lr, "name": "base"},
        {"params": alpha_params, "lr": base_lr * alpha_lr_scale, "name": "alpha_gate"},
    ]


def verify_no_nan(model):
    """Print and return ``False`` if any parameter contains NaN, otherwise return ``True``."""
    nan_params = [
        f" ✗ NaN in {n} shape={list(p.shape)}" for n, p in model.named_parameters() if p.data.isnan().any()
    ]
    if nan_params:
        print("[verify_no_nan] 发现 NaN 参数:\n" + "\n".join(nan_params))
        return False
    print(f"[verify_no_nan] ✓ 所有 {sum(1 for _ in model.parameters())} 个参数均无 NaN")
    return True


def verify_alpha_zero(model):
    """Check that every ``output_alpha`` / ``warmup_alpha`` is 0 (within 1e-6)."""
    problems = []
    for name, param in model.named_parameters():
        if name.endswith(".self_attn.output_alpha") or name.endswith(".shared_gate.warmup_alpha"):
            if abs(param.item()) > 1e-6:
                problems.append(f" ✗ {name} = {param.item():.6f}(应为 0.0)")
    if problems:
        print("[verify_alpha_zero] Alpha 初始化异常:\n" + "\n".join(problems))
        return False
    print("[verify_alpha_zero] ✓ 所有 output_alpha / warmup_alpha = 0.0")
    return True


__all__ = [
    "_SHARED_RANK",
    "_SOFT_RANK_MID",
    "_SOFT_RANK_DEEP",
    "_layer_role",
    "_mla_rank",
    "_mla_zone",
    "_mla_sliding_window",
    "_svd_project_kv",
    "HybridCache",
    "SharedLatentGate",
    "Qwen2MLASoftAttention",
    "Qwen2HybridDecoderLayer",
    "Qwen2HybridModel",
    "Qwen2HybridForCausalLM",
    "migrate_weights_from_qwen2",
    "get_alpha_param_groups",
    "verify_no_nan",
    "verify_alpha_zero",
]