import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaMLP,
    LlamaRMSNorm,
    LlamaRotaryEmbedding,
    apply_rotary_pos_emb,
)


# ============================================================================
# Configuration
# ============================================================================

class NanoHammerConfig(PretrainedConfig):
    model_type = "nanohammer"

    def __init__(
        self,
        vocab_size=128256,
        hidden_size=2048,
        intermediate_size=8192,
        num_hidden_layers=24,
        num_attention_heads=32,
        num_key_value_heads=8,  # GQA
        num_state_heads=32,
        state_hidden_size=None,  # defaults to hidden_size // 4
        max_position_embeddings=131072,
        rms_norm_eps=1e-5,
        initializer_range=0.02,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        mlp_bias=False,
        hidden_act="silu",
        # Special token ids
        bos_token_id=128000,
        eos_token_id=128009,
        pad_token_id=None,
        **kwargs
    ):
        # Register auto_map so the model can be loaded with trust_remote_code
        if "auto_map" not in kwargs:
            kwargs["auto_map"] = {
                "AutoConfig": "NanoHammerForCausalLM.NanoHammerConfig",
                "AutoModelForCausalLM": "NanoHammerForCausalLM.NanoHammerForCausalLM",
            }

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.num_state_heads = num_state_heads
        # Width of the holographic integral state; defaults to hidden_size // 4
        self.state_hidden_size = state_hidden_size if state_hidden_size is not None else hidden_size // 4
        self.max_position_embeddings = max_position_embeddings
        self.rms_norm_eps = rms_norm_eps
        self.initializer_range = initializer_range
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.mlp_bias = mlp_bias
        self.hidden_act = hidden_act

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            **kwargs
        )
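# ----------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the model API): how the derived state
# width falls out of the defaults above. The helper name below is ours.
# ----------------------------------------------------------------------------
def _example_default_config() -> NanoHammerConfig:
    """Build a default config and sanity-check the derived state dimensions.

    With the defaults above, state_hidden_size = 2048 // 4 = 512, which must be
    divisible by num_state_heads (32) for the multi-head state update.
    """
    config = NanoHammerConfig()
    assert config.state_hidden_size == config.hidden_size // 4
    assert config.state_hidden_size % config.num_state_heads == 0
    return config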
# ============================================================================
# Core component 1: Holographic Rotary Position Embedding
# ============================================================================

class HolographicRotaryEmbedding(nn.Module):
    """
    Holographic rotary position embedding - injects temporal structure into the
    integral state.

    Core idea:
    - Each position i is rotated in the complex plane: x_i * e^(i*theta_k)
    - After integration: S_t = sum_i(x_i * e^(i*theta_k)); the state becomes a
      "container of polynomial coefficients"
    - Applying the inverse rotation R_{-t} converts the state to a relative
      coordinate frame, giving translation invariance

    Note: uses absolute position_ids rather than a relative seq_len offset.
    """

    def __init__(self, dim, max_position_embeddings=131072, base=10000):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # Frequencies: theta_k = base^(-2k/d)
        inv_freq = 1.0 / (self.base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, x, position_ids):
        """
        Apply the rotary encoding at absolute positions.

        Args:
            x: (B, T, D) - input tensor
            position_ids: (B, T) - absolute position indices

        Returns:
            x_rotated: (B, T, D) - rotated tensor
        """
        # position_ids: (B, T), inv_freq: (D/2,) -> freqs: (B, T, D/2)
        freqs = torch.einsum("bt,d->btd", position_ids.to(x.dtype), self.inv_freq.to(x.dtype))

        # cos/sin per frequency: (B, T, D/2). Each pair (x[..., 2k], x[..., 2k+1])
        # is rotated by the same angle theta_k * position.
        cos = freqs.cos()
        sin = freqs.sin()

        x1 = x[..., 0::2]
        x2 = x[..., 1::2]
        x1_rotated = x1 * cos - x2 * sin
        x2_rotated = x1 * sin + x2 * cos

        x_rotated = torch.stack([x1_rotated, x2_rotated], dim=-1).flatten(-2)
        return x_rotated

    def apply_inverse_rotation(self, x, position_ids):
        """
        Apply the inverse rotation, converting to a relative coordinate frame
        (again using absolute positions).

        Core: S_t' = S_t * e^(-t*theta), viewing the integral state from the
        current position.

        Args:
            x: (B, T, D) - integral state tensor
            position_ids: (B, T) - absolute position indices

        Returns:
            x_relative: (B, T, D) - state in relative coordinates
        """
        freqs = torch.einsum("bt,d->btd", position_ids.to(x.dtype), self.inv_freq.to(x.dtype))

        cos = freqs.cos()
        sin = freqs.sin()

        # Rotation by -theta: the sign pattern below is the inverse of forward()
        x1 = x[..., 0::2]
        x2 = x[..., 1::2]
        x1_relative = x1 * cos + x2 * sin
        x2_relative = -x1 * sin + x2 * cos

        x_relative = torch.stack([x1_relative, x2_relative], dim=-1).flatten(-2)
        return x_relative
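# ----------------------------------------------------------------------------
# Sanity-check sketch (illustrative): rotate per absolute position, prefix-sum,
# then inverse-rotate at the query position. The result should depend only on
# relative offsets, so shifting every position id by a constant must not change
# it. The helper name, sizes, and tolerance are ours, not part of the model.
# ----------------------------------------------------------------------------
def _check_holographic_shift_invariance(atol: float = 1e-4) -> bool:
    """Numerically verify translation invariance of the holographic state."""
    torch.manual_seed(0)
    batch, seq, dim = 2, 16, 64
    rope = HolographicRotaryEmbedding(dim)
    x = torch.randn(batch, seq, dim)
    pos = torch.arange(seq).unsqueeze(0).expand(batch, -1)

    def relative_state(position_ids):
        rotated = rope(x, position_ids)                # R_{p_i} x_i
        cumulative = torch.cumsum(rotated, dim=1)      # S_t = sum_{i<=t} R_{p_i} x_i
        return rope.apply_inverse_rotation(cumulative, position_ids)  # R_{-p_t} S_t

    shifted = relative_state(pos + 37)                 # same tokens, shifted positions
    return torch.allclose(relative_state(pos), shifted, atol=atol)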
# ============================================================================
# Core component 2: Multi-Head State Update Cell
# ============================================================================

class StateUpdateCell(nn.Module):
    """
    Multi-Head State Update Cell - fixed-point iteration via the Euler method.

    Nonlinear evolution on top of the holographic integral state:
    - S_{t+1} = S_t + alpha * f(S_t)
    - each head iterates in its own subspace
    - learnable step size alpha
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.state_hidden_size
        self.num_heads = config.num_state_heads
        self.head_dim = config.state_hidden_size // config.num_state_heads
        assert config.state_hidden_size % config.num_state_heads == 0

        # Pre-Norm
        self.pre_norm = LlamaRMSNorm(config.state_hidden_size, eps=config.rms_norm_eps)

        # MLP update function f
        self.mlp = nn.Sequential(
            nn.Linear(config.state_hidden_size, config.state_hidden_size * 4, bias=False),
            nn.SiLU(),
            nn.Linear(config.state_hidden_size * 4, config.state_hidden_size, bias=False),
        )

        # Post-Norm
        self.post_norm = LlamaRMSNorm(config.state_hidden_size, eps=config.rms_norm_eps)

        # Euler step size (learnable, one per head)
        self.step_size = nn.Parameter(torch.ones(self.num_heads) * 0.1)

    def forward(self, state):
        """
        Euler update: S_{t+1} = S_t + alpha * f(S_t)

        Args:
            state: (B, T, state_hidden_size)

        Returns:
            state: (B, T, state_hidden_size)
        """
        batch_size, seq_len, _ = state.shape

        # Pre-Norm
        state_normed = self.pre_norm(state)

        # MLP increment
        delta = self.mlp(state_normed)  # (B, T, state_hidden_size)

        # Split into heads
        state_heads = state.view(batch_size, seq_len, self.num_heads, self.head_dim)
        delta_heads = delta.view(batch_size, seq_len, self.num_heads, self.head_dim)

        # Per-head step size
        step_size = self.step_size.view(1, 1, self.num_heads, 1)
        state_heads = state_heads + step_size * delta_heads

        # Merge heads
        state = state_heads.view(batch_size, seq_len, self.hidden_size)

        # Post-Norm
        state = self.post_norm(state)

        return state


# ============================================================================
# Core component 3: State Token Projection
# ============================================================================

class StateTokenProjection(nn.Module):
    """
    State -> token projection: maps the holographic integral state to a
    hidden_size-dimensional token.

    The resulting state token is prepended to the sequence and takes part in
    the standard Llama attention.
    """

    def __init__(self, config):
        super().__init__()
        self.state_to_hidden = nn.Linear(config.state_hidden_size, config.hidden_size, bias=False)
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """
        Args:
            state: (B, T, state_hidden_size) - holographic integral state

        Returns:
            state_token: (B, T, hidden_size) - projected state token
        """
        state_token = self.state_to_hidden(state)
        state_token = self.norm(state_token)
        return state_token
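# ----------------------------------------------------------------------------
# Shape sketch (illustrative): one Euler update followed by the state-token
# projection, on a deliberately tiny config. The config values and helper name
# are ours, chosen only to keep the example small.
# ----------------------------------------------------------------------------
def _sketch_state_update_and_projection() -> torch.Size:
    """Run a single S <- S + alpha * f(S) step and project it to a state token."""
    config = NanoHammerConfig(
        hidden_size=256, intermediate_size=512, num_hidden_layers=2,
        num_attention_heads=8, num_key_value_heads=4,
        num_state_heads=8, state_hidden_size=64,
    )
    cell = StateUpdateCell(config)
    projection = StateTokenProjection(config)

    state = torch.randn(2, 5, config.state_hidden_size)  # (B, T, state_hidden_size)
    state = cell(state)                                   # per-head Euler step
    state_tokens = projection(state)                      # (B, T, hidden_size)
    return state_tokens.shape                             # torch.Size([2, 5, 256])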
# ============================================================================
# Core component 4: Hybrid Decoder Layer (State Tokens + Llama Attention)
# ============================================================================

class HybridNanoHammerDecoderLayer(nn.Module):
    """
    NanoHammer decoder layer: causal state-token prefix + standard Llama
    attention + MLP.

    Pipeline:
    1. State update: nonlinear evolution of the holographic integral state
    2. Causal state-token generation:
       - for position i, state_token_i holds the cumulative state up to i-1
       - all state tokens are placed as a prefix:
         [state_0, ..., state_{T-1}, hidden_0, ..., hidden_{T-1}]
       - the attended sequence length becomes 2T
    3. Special attention mask:
       - hidden_i may attend to state_j (j <= i) and hidden_j (j < i)
       - this keeps causality: position i only sees its own history
    4. Self-attention: standard Llama attention over the 2T-long sequence
    5. MLP: feed-forward network
    6. The state tokens are dropped from the output; only the hidden states
       are returned

    Key design points:
    - the state tokens act as a "global prefix", one per position's history
    - a carefully constructed attention mask preserves causality
    - the "state token first" idea is preserved within a standard Llama block

    Incremental generation:
    - each step processes one new token plus one cumulative state token
    - the state already summarizes the history (the holographic integral is
      cumulative), and the keys/values of the prefix are cached
    - per-step cost: O(T) for attention over the cached prefix; the QKV
      projections themselves are O(1)
    """

    def __init__(self, config, layer_idx: int, holographic_rope: HolographicRotaryEmbedding):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx
        self.holographic_rope = holographic_rope

        # 0. State update cell
        self.state_cell = StateUpdateCell(config)

        # 1. State token projection
        self.state_projection = StateTokenProjection(config)

        # 2. Self-attention (standard Llama)
        self.self_attn = LlamaAttention(config, layer_idx)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Llama RoPE for position embeddings
        self.rotary_emb = LlamaRotaryEmbedding(config=config)

        # 3. MLP
        self.mlp = LlamaMLP(config)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        state: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_state: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple]] = None,
        use_cache: bool = False,
        **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple]]]:
        """
        Args:
            hidden_states: (B, T, hidden_size)
            state: (B, T_state, state_hidden_size)
            position_ids: (B, T)
            attention_mask: (B, 1, T, T); currently unused by the prefill path,
                which builds its own 2T x 2T mask
            past_state: tuple (state_last, state_tokens, position_ids, kv_cache) or None
            use_cache: bool

        Returns:
            hidden_states, state, present_state
        """
        batch_size, seq_len, _ = hidden_states.shape

        # Incremental mode: a cache exists and exactly one new token arrives
        is_incremental = past_state is not None and seq_len == 1

        if is_incremental:
            state_last, state_tokens_cache, position_ids_cache, kv_cache = past_state
            return self._forward_incremental(
                hidden_states, state, position_ids,
                state_last, state_tokens_cache, position_ids_cache, kv_cache
            )
        else:
            return self._forward_prefill(
                hidden_states, state, position_ids, attention_mask, use_cache
            )

    def _forward_incremental(
        self,
        hidden_states: torch.Tensor,            # (B, 1, hidden_size)
        state: torch.Tensor,                     # (B, 1, state_hidden_size) - state input for this layer
        position_ids: torch.LongTensor,          # (B, 1)
        past_state: torch.Tensor,                # (B, 1, state_hidden_size) - this layer's state_cell output from the previous step
        past_state_tokens: torch.Tensor,         # (B, T_past, hidden_size) - cached state tokens
        past_position_ids: torch.LongTensor,     # (B, T_past) - cached positions
        past_kv_cache: Tuple[torch.Tensor, torch.Tensor],  # (keys, values) from previous steps
    ) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple]]:
        """
        Incremental generation with a KV cache.

        Cache structure: (state_last, state_tokens, position_ids, kv_cache)
        Per-step cost: O(T) for attention over the cached prefix; the QKV
        projections only touch the two new tokens.
        """
        batch_size = hidden_states.shape[0]
        device = hidden_states.device
        dtype = hidden_states.dtype
        T_past = past_state_tokens.shape[1]

        # 1. Nonlinear state update
        state_updated = self.state_cell(state)  # (B, 1, state_hidden_size)

        # 2. State token for the current position (built from past_state to keep causality)
        state_relative = self.holographic_rope.apply_inverse_rotation(past_state, position_ids)
        new_state_token = self.state_projection(state_relative)  # (B, 1, hidden_size)

        # 3. Extend the state-token cache (used by the next step)
        all_state_tokens = torch.cat([past_state_tokens, new_state_token], dim=1)
        all_position_ids = torch.cat([past_position_ids, position_ids], dim=1)

        # 4. Inputs for this step: just two tokens [new_state_token, new_hidden_token]
        query_input = torch.cat([new_state_token, hidden_states], dim=1)  # (B, 2, hidden_size)

        # 5. Position ids for the queries: state and hidden share the same position
        query_position_ids = torch.cat([position_ids, position_ids], dim=1)  # (B, 2)

        # 6. Attention mask.
        # KV layout after the cache concat below:
        #   [state_0, ..., state_{T-1}, hidden_0, ..., hidden_{T-1}, state_T, hidden_T]
        # Query 0 (state_T): sees state_0..state_{T-1} and itself
        # Query 1 (hidden_T): sees state_0..state_T and hidden_0..hidden_{T-1}
        #   (no self-attention, matching the prefill mask)
        kv_len = 2 * T_past + 2  # past states + past hiddens + new state + new hidden
        attention_mask = torch.full(
            (batch_size, 1, 2, kv_len),
            torch.finfo(dtype).min,
            dtype=dtype, device=device
        )
        # state_T: past states (0..T_past-1) + new state (index 2*T_past)
        attention_mask[:, :, 0, :T_past] = 0
        attention_mask[:, :, 0, 2 * T_past] = 0
        # hidden_T: all states + past hiddens (but not itself, as in prefill)
        attention_mask[:, :, 1, :T_past] = 0              # past states
        attention_mask[:, :, 1, T_past:2 * T_past] = 0    # past hiddens
        attention_mask[:, :, 1, 2 * T_past] = 0           # new state

        # 7. LayerNorm
        residual = query_input
        query_normed = self.input_layernorm(query_input)

        # 8. QKV projections
        head_dim = self.self_attn.head_dim
        num_heads = self.self_attn.config.num_attention_heads
        num_kv_heads = self.self_attn.config.num_key_value_heads
        num_kv_groups = num_heads // num_kv_heads

        query_states = self.self_attn.q_proj(query_normed)
        key_states = self.self_attn.k_proj(query_normed)
        value_states = self.self_attn.v_proj(query_normed)

        # Reshape: (B, 2, H*D) -> (B, H, 2, D)
        query_states = query_states.view(batch_size, 2, num_heads, head_dim).transpose(1, 2)
        key_states = key_states.view(batch_size, 2, num_kv_heads, head_dim).transpose(1, 2)
        value_states = value_states.view(batch_size, 2, num_kv_heads, head_dim).transpose(1, 2)

        # 9. RoPE
        cos, sin = self.rotary_emb(query_normed, query_position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # 10. Concatenate the KV cache
        past_keys, past_values = past_kv_cache
        key_states = torch.cat([past_keys, key_states], dim=2)
        value_states = torch.cat([past_values, value_states], dim=2)

        # 11. Attention (expand KV heads for GQA)
        key_states_expanded = key_states.repeat_interleave(num_kv_groups, dim=1)
        value_states_expanded = value_states.repeat_interleave(num_kv_groups, dim=1)

        attn_weights = torch.matmul(query_states, key_states_expanded.transpose(2, 3)) * (head_dim ** -0.5)
        attn_weights = attn_weights + attention_mask
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states_expanded)

        # 12. Output projection
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, 2, -1)
        attn_output = self.self_attn.o_proj(attn_output)
        combined_output = residual + attn_output

        # 13. MLP
        residual = combined_output
        combined_output = self.post_attention_layernorm(combined_output)
        combined_output = self.mlp(combined_output)
        combined_output = residual + combined_output

        # 14. Keep only the hidden token (second position)
        hidden_output = combined_output[:, 1:, :]  # (B, 1, hidden_size)

        # 15. Return the updated cache
        new_kv_cache = (key_states, value_states)
        present_state = (state_updated, all_state_tokens, all_position_ids, new_kv_cache)

        return hidden_output, state_updated, present_state

    def _forward_prefill(
        self,
        hidden_states: torch.Tensor,
        state: torch.Tensor,
        position_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor],
        use_cache: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple]]]:
        """
        Prefill mode: process the full sequence.

        Note: the incoming attention_mask (padding) is not folded into the
        2T x 2T mask built below; the extended mask only encodes causality.

        Returned present_state: (state_last, state_tokens, position_ids, kv_cache)
        """
        batch_size, seq_len, _ = hidden_states.shape

        # 0. Nonlinear state update
        state = self.state_cell(state)

        # 1. Causal state per position: state_token_i may only contain history up to i-1
        if position_ids is None:
            raise ValueError("position_ids is required for inverse rotation")

        state_shifted = torch.cat([
            torch.zeros_like(state[:, :1, :]),
            state[:, :-1, :]
        ], dim=1)

        # 2. Inverse rotation into relative coordinates
        state_relative = self.holographic_rope.apply_inverse_rotation(state_shifted, position_ids)

        # 3. Project to state tokens
        state_tokens = self.state_projection(state_relative)
        if state_tokens.shape[1] > seq_len:
            state_tokens = state_tokens[:, -seq_len:, :]

        # 4. Concatenate: [state tokens | hidden states]
        combined_input = torch.cat([state_tokens, hidden_states], dim=1)

        # 5. Position ids for the 2T sequence (state and hidden share positions)
        if position_ids.shape[1] > seq_len:
            current_position_ids = position_ids[:, -seq_len:]
        else:
            current_position_ids = position_ids
        extended_position_ids = torch.cat([current_position_ids, current_position_ids], dim=1)

        # 6. Causal attention mask over the 2T sequence
        extended_mask = torch.full(
            (batch_size, 1, 2 * seq_len, 2 * seq_len),
            torch.finfo(combined_input.dtype).min,
            dtype=combined_input.dtype, device=combined_input.device
        )
        for i in range(seq_len):
            # state_i sees state_0..state_i
            extended_mask[:, :, i, :i + 1] = 0
        for i in range(seq_len):
            # hidden_i sees state_0..state_i ...
            extended_mask[:, :, seq_len + i, :i + 1] = 0
            if i > 0:
                # ... and hidden_0..hidden_{i-1}
                extended_mask[:, :, seq_len + i, seq_len:seq_len + i] = 0

        # 7. Self-attention (computed manually so the KV cache can be captured)
        residual = combined_input
        hidden_states_normed = self.input_layernorm(combined_input)
        position_embeddings = self.rotary_emb(hidden_states_normed, extended_position_ids)

        # QKV projections
        head_dim = self.self_attn.head_dim
        num_heads = self.self_attn.config.num_attention_heads
        num_kv_heads = self.self_attn.config.num_key_value_heads
        num_kv_groups = num_heads // num_kv_heads

        query_states = self.self_attn.q_proj(hidden_states_normed)
        key_states = self.self_attn.k_proj(hidden_states_normed)
        value_states = self.self_attn.v_proj(hidden_states_normed)

        # Reshape to (B, H, 2T, D)
        query_states = query_states.view(batch_size, 2 * seq_len, num_heads, head_dim).transpose(1, 2)
        key_states = key_states.view(batch_size, 2 * seq_len, num_kv_heads, head_dim).transpose(1, 2)
        value_states = value_states.view(batch_size, 2 * seq_len, num_kv_heads, head_dim).transpose(1, 2)

        # RoPE
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Attention (expand KV heads for GQA)
        key_states_expanded = key_states.repeat_interleave(num_kv_groups, dim=1)
        value_states_expanded = value_states.repeat_interleave(num_kv_groups, dim=1)

        attn_weights = torch.matmul(query_states, key_states_expanded.transpose(2, 3)) * (head_dim ** -0.5)
        attn_weights = attn_weights + extended_mask
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states_expanded)

        # Output projection
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, 2 * seq_len, -1)
        attn_output = self.self_attn.o_proj(attn_output)
        combined_output = residual + attn_output

        # 8. MLP
        residual = combined_output
        combined_output = self.post_attention_layernorm(combined_output)
        combined_output = self.mlp(combined_output)
        combined_output = residual + combined_output

        # 9. Keep only the hidden-state half of the output
        hidden_states_output = combined_output[:, seq_len:, :]

        # 10. Cache: (state_last, state_tokens, position_ids, kv_cache)
        if use_cache:
            kv_cache = (key_states, value_states)
            present_state = (state[:, -1:, :], state_tokens, current_position_ids, kv_cache)
        else:
            present_state = None

        return hidden_states_output, state, present_state
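# ----------------------------------------------------------------------------
# Mask sketch (illustrative): boolean visibility patterns equivalent to the
# additive masks built inside the layer above, for a tiny sequence. The helper
# names and default sizes are ours.
# ----------------------------------------------------------------------------
def _sketch_prefill_visibility(seq_len: int = 3) -> torch.Tensor:
    """2T x 2T prefill visibility; rows/cols are [state_0..state_{T-1}, hidden_0..hidden_{T-1}]."""
    t = seq_len
    visible = torch.zeros(2 * t, 2 * t, dtype=torch.bool)
    for i in range(t):
        visible[i, : i + 1] = True          # state_i -> state_{0..i}
        visible[t + i, : i + 1] = True      # hidden_i -> state_{0..i}
        visible[t + i, t:t + i] = True      # hidden_i -> hidden_{0..i-1}
    return visible


def _sketch_decode_visibility(t_past: int = 3) -> torch.Tensor:
    """2 x (2*T_past + 2) decode-step visibility over the KV layout
    [state_0..state_{T-1}, hidden_0..hidden_{T-1}, state_T, hidden_T]."""
    kv_len = 2 * t_past + 2
    visible = torch.zeros(2, kv_len, dtype=torch.bool)
    visible[0, :t_past] = True              # state_T -> past states
    visible[0, 2 * t_past] = True           # state_T -> itself
    visible[1, :t_past] = True              # hidden_T -> past states
    visible[1, t_past:2 * t_past] = True    # hidden_T -> past hiddens
    visible[1, 2 * t_past] = True           # hidden_T -> state_T
    return visible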
# ============================================================================
# Main model
# ============================================================================

class NanoHammerPreTrainedModel(PreTrainedModel):
    config_class = NanoHammerConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["HybridNanoHammerDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _is_stateful = True  # tell HuggingFace that we carry a custom cache

    @classmethod
    def _supports_default_dynamic_cache(cls) -> bool:
        """
        Return False to opt out of HuggingFace's DynamicCache machinery.
        NanoHammer uses its own state-token cache structure.
        """
        return False

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
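# ----------------------------------------------------------------------------
# Cache layout sketch (illustrative): the nested tuple that replaces a
# DynamicCache during generation. The helper name is ours; the shapes follow
# the code below for batch size B and prefill length T.
# ----------------------------------------------------------------------------
def _describe_past_state_absolute(config: NanoHammerConfig, batch: int, t: int) -> dict:
    """Return the expected tensor shapes of the custom cache after a T-token prefill."""
    head_dim = config.hidden_size // config.num_attention_heads
    per_layer = {
        "state_last": (batch, 1, int(config.state_hidden_size)),
        "state_tokens": (batch, t, config.hidden_size),
        "position_ids": (batch, t),
        # keys and values cover the 2T-long [state | hidden] prefix
        "kv_cache": ((batch, config.num_key_value_heads, 2 * t, head_dim),) * 2,
    }
    return {
        "cumsum_cache": (batch, 1, int(config.state_hidden_size)),
        "layer_caches": [per_layer] * config.num_hidden_layers,
    }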
class NanoHammerModel(NanoHammerPreTrainedModel):
    """
    NanoHammer backbone.

    Architecture:
    1. Token embedding + state initialization (holographic integral)
    2. N x HybridNanoHammerDecoderLayer (state update + state-token prefix + self-attention + MLP)
    3. Final norm

    Incremental generation:
    - past_state_absolute caches the running holographic cumsum plus, per layer,
      (state_last, state_tokens, position_ids, kv_cache)
    - each generation step processes one new token plus one state token
    - the cumulative state itself is O(num_layers * state_hidden_size); the
      state-token and KV caches grow linearly with the generated length
    """

    def __init__(self, config: NanoHammerConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.num_hidden_layers = config.num_hidden_layers
        self.state_hidden_size = int(config.state_hidden_size)

        # Token embedding
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)

        # Token -> state projection
        self.token_to_state = nn.Linear(config.hidden_size, self.state_hidden_size, bias=False)

        # Holographic rotary position embedding
        self.holographic_rope = HolographicRotaryEmbedding(
            self.state_hidden_size,
            max_position_embeddings=config.max_position_embeddings
        )

        # Decoder layers
        self.layers = nn.ModuleList([
            HybridNanoHammerDecoderLayer(config, layer_idx, self.holographic_rope)
            for layer_idx in range(config.num_hidden_layers)
        ])

        # Final norm
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_state_absolute: Optional[Tuple] = None,  # custom cache structure
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """
        Args:
            past_state_absolute: tuple (cumsum_cache, layer_caches)
                - cumsum_cache: (B, 1, state_hidden_size)
                - layer_caches: list of (state_last, state_tokens, position_ids, kv_cache), one per layer
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # 1. Token embedding
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        batch_size, seq_len, _ = inputs_embeds.shape
        device = inputs_embeds.device

        # Incremental mode: a cache exists and exactly one new token arrives
        is_incremental = past_state_absolute is not None and seq_len == 1

        # 2. Position ids
        if position_ids is None:
            if is_incremental:
                raise ValueError("position_ids must be provided in incremental mode")
            full_position_ids = torch.arange(0, seq_len, dtype=torch.long, device=device)
            full_position_ids = full_position_ids.unsqueeze(0).expand(batch_size, -1)
        else:
            full_position_ids = position_ids

        # 3. Token -> state projection + holographic rotation
        state_input = self.token_to_state(inputs_embeds)
        state_input_rotated = self.holographic_rope(state_input, full_position_ids)

        # 4. Mode-dependent processing
        if is_incremental:
            # ============ Incremental generation ============
            cumsum_cache, layer_caches = past_state_absolute

            # Extend the running cumulative sum
            new_cumsum = cumsum_cache + state_input_rotated
            state = new_cumsum
            hidden_states = inputs_embeds

            all_hidden_states = () if output_hidden_states else None
            new_layer_caches = []

            for layer_idx, decoder_layer in enumerate(self.layers):
                if output_hidden_states:
                    all_hidden_states += (hidden_states,)

                past_state_layer = layer_caches[layer_idx]
                hidden_states, state, present_state = decoder_layer(
                    hidden_states,
                    state,
                    full_position_ids,
                    attention_mask=None,
                    past_state=past_state_layer,
                    use_cache=True,
                )
                new_layer_caches.append(present_state)

            new_past_state_absolute = (new_cumsum, new_layer_caches)

        else:
            # ============ Prefill ============
            cumsum_state = torch.cumsum(state_input_rotated, dim=1)
            state = cumsum_state
            hidden_states = inputs_embeds

            all_hidden_states = () if output_hidden_states else None
            layer_caches = []

            # Build the (B, 1, T, T) additive mask. Note: in prefill the decoder
            # layers build their own 2T x 2T causal masks and do not consume this one.
            if attention_mask is None:
                causal_attention_mask = torch.zeros(
                    (batch_size, 1, seq_len, seq_len),
                    dtype=inputs_embeds.dtype, device=device
                )
                causal_mask = torch.triu(torch.ones((seq_len, seq_len), device=device), diagonal=1).bool()
                causal_attention_mask[:, :, causal_mask] = torch.finfo(inputs_embeds.dtype).min
            elif attention_mask.dim() == 2:
                expanded_mask = torch.zeros(
                    (batch_size, 1, seq_len, seq_len),
                    dtype=inputs_embeds.dtype, device=device
                )
                causal_mask = torch.triu(torch.ones((seq_len, seq_len), device=device), diagonal=1).bool()
                expanded_mask[:, :, causal_mask] = torch.finfo(inputs_embeds.dtype).min
                for b in range(batch_size):
                    for i in range(seq_len):
                        if attention_mask[b, i] == 0:
                            expanded_mask[b, 0, :, i] = torch.finfo(inputs_embeds.dtype).min
                causal_attention_mask = expanded_mask
            else:
                causal_attention_mask = attention_mask

            for decoder_layer in self.layers:
                if output_hidden_states:
                    all_hidden_states += (hidden_states,)

                if self.gradient_checkpointing and self.training:
                    hidden_states, state, present_state = self._gradient_checkpointing_func(
                        decoder_layer.__call__,
                        hidden_states,
                        state,
                        full_position_ids,
                        causal_attention_mask,
                        None,
                        use_cache,
                    )
                else:
                    hidden_states, state, present_state = decoder_layer(
                        hidden_states,
                        state,
                        full_position_ids,
                        causal_attention_mask,
                        past_state=None,
                        use_cache=use_cache,
                    )

                if use_cache and present_state is not None:
                    layer_caches.append(present_state)

            # Assemble the cache
            if use_cache and layer_caches:
                cumsum_cache = cumsum_state[:, -1:, :]
                new_past_state_absolute = (cumsum_cache, layer_caches)
            else:
                new_past_state_absolute = None

        # 5. Final norm
        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # 6. Outputs
        if not return_dict:
            outputs = (hidden_states,)
            if output_hidden_states:
                outputs += (all_hidden_states,)
            if use_cache:
                outputs += (new_past_state_absolute,)
            return outputs

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=new_past_state_absolute if use_cache else None,  # custom cache rides in past_key_values
            hidden_states=all_hidden_states,
            attentions=None,
        )
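# ----------------------------------------------------------------------------
# Mask sketch (illustrative): a vectorized equivalent of the Python-loop
# expansion of a 2D padding mask performed in NanoHammerModel.forward above.
# The helper name is ours; the model itself keeps the explicit loops.
# ----------------------------------------------------------------------------
def _expand_padding_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Turn a (B, T) padding mask into the additive (B, 1, T, T) causal mask."""
    batch_size, seq_len = attention_mask.shape
    min_value = torch.finfo(dtype).min
    causal = torch.triu(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=attention_mask.device), diagonal=1
    )
    mask = torch.zeros(batch_size, 1, seq_len, seq_len, dtype=dtype, device=attention_mask.device)
    mask.masked_fill_(causal, min_value)                    # causal part
    padded_keys = (attention_mask == 0)[:, None, None, :]   # block attention to padded positions
    return mask.masked_fill(padded_keys, min_value)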
class NanoHammerForCausalLM(NanoHammerPreTrainedModel):
    """
    NanoHammer language model head.

    Incremental generation keeps, per layer, the cached state tokens and the
    keys/values over the 2T-long [state | hidden] prefix, plus the running
    holographic state, so memory grows linearly in the generated length
    (roughly a standard Llama KV cache over 2T positions plus T state tokens
    per layer).
    """

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = NanoHammerModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple] = None,  # standard interface; holds the custom cache
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_state_absolute=past_key_values,  # forwarded to the backbone
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        hidden_states = outputs.last_hidden_state
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + (outputs.hidden_states,)
            if use_cache:
                output = output + (outputs.past_key_values,)
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,  # standard field carrying the custom cache
            hidden_states=outputs.hidden_states,
            attentions=None,
        )
    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs
    ):
        """Prepare inputs for generation."""
        position_ids = kwargs.get("position_ids", None)

        if past_key_values is not None:
            # Incremental mode: keep only the last token
            input_ids = input_ids[:, -1:]
            if position_ids is not None:
                position_ids = position_ids[:, -1:]
            elif attention_mask is not None:
                position_ids = attention_mask.long().cumsum(-1)[:, -1:] - 1
            else:
                raise ValueError("position_ids or attention_mask required in incremental mode")
            inputs_embeds = None
        else:
            # Prefill mode
            if position_ids is None and attention_mask is not None:
                position_ids = attention_mask.long().cumsum(-1) - 1
                position_ids.masked_fill_(attention_mask == 0, 1)
            if inputs_embeds is not None and input_ids.shape[1] != inputs_embeds.shape[1]:
                inputs_embeds = None

        return {
            "input_ids": input_ids if inputs_embeds is None else None,
            "inputs_embeds": inputs_embeds,
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache", True),
            "attention_mask": attention_mask,
        }
    def _update_model_kwargs_for_generation(
        self,
        outputs,
        model_kwargs,
        is_encoder_decoder=False,
        num_new_tokens=1,
    ):
        """Update model_kwargs between generation steps."""
        # Carry the custom cache forward
        model_kwargs["past_key_values"] = outputs.past_key_values

        # Extend the attention mask by the newly generated tokens
        if "attention_mask" in model_kwargs:
            attention_mask = model_kwargs["attention_mask"]
            model_kwargs["attention_mask"] = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], num_new_tokens))],
                dim=-1,
            )

        return model_kwargs
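# ----------------------------------------------------------------------------
# Smoke-test sketch (illustrative): a tiny prefill followed by one cached
# decode step. The miniature config values are ours, chosen only to keep the
# test small; it assumes a transformers version compatible with the Llama
# modules imported above.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    config = NanoHammerConfig(
        vocab_size=128, hidden_size=128, intermediate_size=256,
        num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=2,
        num_state_heads=4, state_hidden_size=32,
        max_position_embeddings=64, bos_token_id=0, eos_token_id=1,
    )
    model = NanoHammerForCausalLM(config).eval()

    input_ids = torch.randint(0, config.vocab_size, (1, 8))
    with torch.no_grad():
        # Prefill: builds the (cumsum_cache, layer_caches) cache
        prefill = model(input_ids=input_ids, use_cache=True)
        next_token = prefill.logits[:, -1:].argmax(dim=-1)

        # One incremental step: position_ids must be supplied explicitly
        step = model(
            input_ids=next_token,
            position_ids=torch.tensor([[input_ids.shape[1]]]),
            past_key_values=prefill.past_key_values,
            use_cache=True,
        )

    print("prefill logits:", tuple(prefill.logits.shape))  # (1, 8, 128)
    print("decode logits:", tuple(step.logits.shape))      # (1, 1, 128)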